// yt - A fully featured command line YouTube client
//
// Copyright (C) 2024 Benedikt Peetz <benedikt.peetz@b-peetz.de>
// SPDX-License-Identifier: GPL-3.0-or-later
//
// This file is part of Yt.
//
// You should have received a copy of the License along with this program.
// If not, see <https://www.gnu.org/licenses/gpl-3.0.txt>.

use super::*;

use std::alloc::{self, Layout};
use std::marker::PhantomData;
use std::mem;
use std::os::raw as ctype;
use std::panic;
use std::panic::RefUnwindSafe;
use std::slice;
use std::sync::{atomic::Ordering, Mutex};

impl Mpv {
    /// Hand out the protocol context through which custom protocols are registered.
    ///
    /// The `protocols_guard` flag is atomically flipped from `false` to `true`; only the
    /// caller that wins this exchange receives a context.
    ///
    /// # Panics
    /// Panics if a context already exists
    pub fn create_protocol_context<T, U>(&self) -> ProtocolContext<T, U>
    where
        T: RefUnwindSafe,
        U: RefUnwindSafe,
    {
        let already_taken = self
            .protocols_guard
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
            .is_err();

        if already_taken {
            panic!("A protocol context already exists");
        }

        ProtocolContext::new(self.ctx, PhantomData::<&Self>)
    }
}

/// Open callback: receives the protocol's user data and the requested URI.
/// Return a persistent `T` that is passed to all other `Stream*` functions. Panic on errors;
/// the panic is caught and reported to mpv as a generic error.
pub type StreamOpen<T, U> = fn(&mut U, &str) -> T;
/// Close callback: takes ownership of the stream state. Do any necessary cleanup here.
/// A panic raised by this callback is caught and discarded.
pub type StreamClose<T> = fn(Box<T>);
/// Seek callback: seek to the given offset and return the new offset.
/// If seeking failed, either return `MpvError::Generic` or panic (a panic is reported to mpv
/// as a generic error).
pub type StreamSeek<T> = fn(&mut T, i64) -> i64;
/// Read callback: fill the target buffer, which has a fixed capacity.
/// Return the number of bytes read, or `0` on EOF. On error, either return `-1` or panic
/// (a panic is reported to mpv as `-1`).
pub type StreamRead<T> = fn(&mut T, &mut [ctype::c_char]) -> i64;
/// Size callback: return the total size of the stream in bytes. Panic on error; the panic is
/// caught and reported to mpv as "unsupported".
pub type StreamSize<T> = fn(&mut T) -> i64;

/// Trampoline that mpv invokes to open a stream of a registered custom protocol.
///
/// `user_data` is the `ProtocolData<T, U>` pointer handed over at registration time. It is
/// stored as the stream cookie in `info`, together with the read/seek/size/close trampolines.
/// Afterwards the user's `open_fn` runs and its result is written into the protocol's
/// cookie slot.
///
/// Returns `0` on success. If converting the URI or running `open_fn` panics, the panic is
/// caught and `mpv_error::Generic` is returned instead.
unsafe extern "C" fn open_wrapper<T, U>(
    user_data: *mut ctype::c_void,
    uri: *mut ctype::c_char,
    info: *mut libmpv2_sys::mpv_stream_cb_info,
) -> ctype::c_int
where
    T: RefUnwindSafe,
    U: RefUnwindSafe,
{
    let protocol = user_data as *mut ProtocolData<T, U>;

    // Every later callback receives the protocol data itself as its cookie.
    (*info).cookie = user_data;
    (*info).read_fn = Some(read_wrapper::<T, U>);
    (*info).seek_fn = Some(seek_wrapper::<T, U>);
    (*info).size_fn = Some(size_wrapper::<T, U>);
    (*info).close_fn = Some(close_wrapper::<T, U>);

    // Only raw pointers cross into the closure; a panic must never unwind into C code.
    let outcome = panic::catch_unwind(|| {
        let uri = mpv_cstr_to_str!(uri as *const _).unwrap();
        let stream = ((*protocol).open_fn)(&mut (*protocol).user_data, uri);
        // `ptr::write` does not drop the previous (uninitialised) contents of the slot.
        ptr::write((*protocol).cookie, stream);
    });

    match outcome {
        Ok(()) => 0,
        Err(_) => mpv_error::Generic as _,
    }
}

/// Trampoline that mpv invokes to read from an opened custom stream.
///
/// `cookie` is the `ProtocolData<T, U>` pointer stored by the open trampoline. `buf` and
/// `nbytes` are turned into a mutable slice that is passed to the user's `read_fn` along
/// with the stream state.
///
/// Returns whatever `read_fn` returns, or `-1` if it panicked.
unsafe extern "C" fn read_wrapper<T, U>(
    cookie: *mut ctype::c_void,
    buf: *mut ctype::c_char,
    nbytes: u64,
) -> i64
where
    T: RefUnwindSafe,
    U: RefUnwindSafe,
{
    let protocol = cookie as *mut ProtocolData<T, U>;

    panic::catch_unwind(|| {
        let target = slice::from_raw_parts_mut(buf, nbytes as usize);
        ((*protocol).read_fn)(&mut *(*protocol).cookie, target)
    })
    .unwrap_or(-1)
}

/// Trampoline that mpv invokes to seek within an opened custom stream.
///
/// Returns `mpv_error::Unsupported` when the protocol was created without a `seek_fn`.
/// Otherwise returns the value produced by `seek_fn`, or `mpv_error::Generic` if it panicked.
unsafe extern "C" fn seek_wrapper<T, U>(cookie: *mut ctype::c_void, offset: i64) -> i64
where
    T: RefUnwindSafe,
    U: RefUnwindSafe,
{
    let protocol = cookie as *mut ProtocolData<T, U>;

    // The callback is a plain `fn` pointer, so it can be copied out of the protocol data.
    let seek = match (*protocol).seek_fn {
        Some(seek) => seek,
        None => return mpv_error::Unsupported as _,
    };

    match panic::catch_unwind(|| seek(&mut *(*protocol).cookie, offset)) {
        Ok(position) => position,
        Err(_) => mpv_error::Generic as _,
    }
}

/// Trampoline that mpv invokes to query the total size of an opened custom stream.
///
/// Returns the value produced by the user's `size_fn`. `mpv_error::Unsupported` is returned
/// both when no `size_fn` was supplied and when `size_fn` panicked.
unsafe extern "C" fn size_wrapper<T, U>(cookie: *mut ctype::c_void) -> i64
where
    T: RefUnwindSafe,
    U: RefUnwindSafe,
{
    let protocol = cookie as *mut ProtocolData<T, U>;

    if let Some(size) = (*protocol).size_fn {
        panic::catch_unwind(|| size(&mut *(*protocol).cookie))
            .unwrap_or(mpv_error::Unsupported as i64)
    } else {
        mpv_error::Unsupported as _
    }
}

#[allow(unused_must_use)]
unsafe extern "C" fn close_wrapper<T, U>(cookie: *mut ctype::c_void)
where
    T: RefUnwindSafe,
    U: RefUnwindSafe,
{
    let data = Box::from_raw(cookie as *mut ProtocolData<T, U>);

    panic::catch_unwind(|| ((*data).close_fn)(Box::from_raw((*data).cookie)));
}

/// Everything the stream trampolines need; a pointer to this struct is what mpv receives as
/// `user_data` at registration and as the stream `cookie` afterwards.
struct ProtocolData<T, U> {
    /// Slot for the stream state `T`: filled with the result of `open_fn` by the open
    /// trampoline and handed to `close_fn` as a `Box<T>` by the close trampoline.
    cookie: *mut T,
    /// User supplied data, passed mutably to `open_fn`.
    user_data: U,

    /// Mandatory callback producing the stream state.
    open_fn: StreamOpen<T, U>,
    /// Mandatory callback consuming the stream state.
    close_fn: StreamClose<T>,
    /// Mandatory callback filling mpv's read buffer.
    read_fn: StreamRead<T>,
    /// Optional; when `None` the seek trampoline reports "unsupported".
    seek_fn: Option<StreamSeek<T>>,
    /// Optional; when `None` the size trampoline reports "unsupported".
    size_fn: Option<StreamSize<T>>,
}

/// This context holds state relevant to custom protocols.
/// It is created by calling `Mpv::create_protocol_context`.
pub struct ProtocolContext<'parent, T: RefUnwindSafe, U: RefUnwindSafe> {
    /// The mpv handle the protocols are registered on.
    ctx: NonNull<libmpv2_sys::mpv_handle>,
    /// All protocols registered through this context so far.
    protocols: Mutex<Vec<Protocol<T, U>>>,
    /// Ties the context to the lifetime of the borrow of the parent `Mpv`.
    _does_not_outlive: PhantomData<&'parent Mpv>,
}

// NOTE(review): these impls hold for every `T` and `U`, without requiring `T: Send`/`U: Send`
// (or `Sync`), while the stored `Protocol`s contain raw pointers to user data — confirm that
// this is sound for all callers.
unsafe impl<'parent, T: RefUnwindSafe, U: RefUnwindSafe> Send for ProtocolContext<'parent, T, U> {}
unsafe impl<'parent, T: RefUnwindSafe, U: RefUnwindSafe> Sync for ProtocolContext<'parent, T, U> {}

impl<'parent, T: RefUnwindSafe, U: RefUnwindSafe> ProtocolContext<'parent, T, U> {
    /// Build a context for the given mpv handle that has no protocols registered yet.
    fn new(
        ctx: NonNull<libmpv2_sys::mpv_handle>,
        marker: PhantomData<&'parent Mpv>,
    ) -> ProtocolContext<'parent, T, U> {
        let protocols = Mutex::new(Vec::new());

        Self {
            ctx,
            protocols,
            _does_not_outlive: marker,
        }
    }

    /// Register a custom `Protocol`. Once a protocol has been registered, it lives as long as
    /// `Mpv`.
    ///
    /// The list of registered protocols stays locked for the whole registration, and the
    /// protocol is only added to that list if mpv accepted it.
    ///
    /// Returns `Error::Mpv(MpvError::InvalidParameter)` if a protocol with the same name has
    /// already been registered.
    ///
    /// # Panics
    /// Panics if the internal protocol list lock is poisoned.
    pub fn register(&self, protocol: Protocol<T, U>) -> Result<()> {
        let mut registered = self.protocols.lock().unwrap();
        let handle = self.ctx.as_ptr();

        match protocol.register(handle) {
            Ok(()) => {
                registered.push(protocol);
                Ok(())
            }
            Err(err) => Err(err),
        }
    }
}

/// `Protocol` holds all state used by a custom protocol.
pub struct Protocol<T: Sized + RefUnwindSafe, U: RefUnwindSafe> {
    /// Prefix of the protocol, e.g. `name://path`.
    name: String,
    /// Heap allocated callback table and user data; this pointer is what gets passed to mpv
    /// as `user_data` on registration.
    data: *mut ProtocolData<T, U>,
}

impl<T: RefUnwindSafe, U: RefUnwindSafe> Protocol<T, U> {
    /// Create a new custom protocol description.
    ///
    /// `name` is the prefix of the protocol, e.g. `name://path`.
    ///
    /// `user_data` is data that will be passed to `open_fn`.
    ///
    /// `open_fn`, `close_fn` and `read_fn` are mandatory, `seek_fn` and `size_fn` are
    /// optional.
    ///
    /// This reserves (uninitialised) storage for one `T`, which is later filled with the value
    /// returned by `open_fn`. If that allocation fails, the global allocation error handler is
    /// invoked. No allocation is performed for a zero-sized `T`.
    ///
    /// # Safety
    /// Do not call libmpv functions in any supplied function.
    /// All panics of the provided functions are caught and can be used as generic error returns.
    pub unsafe fn new(
        name: String,
        user_data: U,
        open_fn: StreamOpen<T, U>,
        close_fn: StreamClose<T>,
        read_fn: StreamRead<T>,
        seek_fn: Option<StreamSeek<T>>,
        size_fn: Option<StreamSize<T>>,
    ) -> Protocol<T, U> {
        let cookie: *mut T = if mem::size_of::<T>() == 0 {
            // Requesting a zero-sized allocation from the global allocator is undefined
            // behaviour. A dangling but well-aligned pointer is the valid representation for
            // zero-sized values, both for `ptr::write` and for `Box::from_raw`.
            std::ptr::NonNull::<T>::dangling().as_ptr()
        } else {
            let c_layout = Layout::new::<T>();
            // SAFETY: `c_layout` has a non-zero size, as checked above.
            let raw = alloc::alloc(c_layout) as *mut T;
            if raw.is_null() {
                // Never hand out a null slot that would be written through later on.
                alloc::handle_alloc_error(c_layout);
            }
            raw
        };

        let data = Box::into_raw(Box::new(ProtocolData {
            cookie,
            user_data,

            open_fn,
            close_fn,
            read_fn,
            seek_fn,
            size_fn,
        }));

        Protocol { name, data }
    }

    /// Register this protocol as a read-only stream protocol on the given mpv handle.
    ///
    /// # Errors
    /// Fails if the protocol name contains an interior NUL byte, or if mpv rejects the
    /// registration.
    fn register(&self, ctx: *mut libmpv2_sys::mpv_handle) -> Result<()> {
        let name = CString::new(&self.name[..])?;
        unsafe {
            mpv_err(
                (),
                libmpv2_sys::mpv_stream_cb_add_ro(
                    ctx,
                    name.as_ptr(),
                    self.data as *mut _,
                    Some(open_wrapper::<T, U>),
                ),
            )
        }
    }
}