// Syd: rock-solid application kernel
// src/kernel/net/accept.rs: accept(2) and accept4(2) handler
//
// Copyright (c) 2023, 2024, 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

use std::os::fd::OwnedFd;

use libseccomp::ScmpNotifResp;
use nix::{
    errno::Errno,
    sys::socket::{SockFlag, SockaddrLike, SockaddrStorage},
};

use crate::{
    compat::getsockdomain,
    confine::op2errno,
    cookie::safe_accept4,
    fs::{get_nonblock, has_recv_timeout},
    hook::UNotifyEventRequest,
    kernel::net::sandbox_addr,
    sandbox::Capability,
};

pub(crate) fn handle_accept(
    fd: OwnedFd,
    request: &UNotifyEventRequest,
    args: &[u64; 6],
    op: u8,
) -> Result<ScmpNotifResp, Errno> {
    // Determine the socket family.
    if !matches!(
        getsockdomain(&fd).map_err(|_| op2errno(op))?,
        libc::AF_INET6 | libc::AF_INET
    ) {
        // Not an IPv{4,6} socket, continue system call.
        //
        // SAFETY: No pointer-dereference in access check.
        return unsafe { Ok(request.continue_syscall()) };
    }

    // Determine address length if specified.
    let addrlen = if args[2] != 0 {
        const SIZEOF_SOCKLEN_T: usize = std::mem::size_of::<libc::socklen_t>();
        let mut buf = [0u8; SIZEOF_SOCKLEN_T];
        if request.read_mem(&mut buf, args[2])? == SIZEOF_SOCKLEN_T {
            // libc defines socklen_t as u32,
            // however we should check for negative values
            // and return EINVAL as necessary.
            let len = i32::from_ne_bytes(buf);
            let len = libc::socklen_t::try_from(len).or(Err(Errno::EINVAL))?;
            if args[1] == 0 {
                // address length is positive however address is NULL,
                // return EFAULT.
                return Err(Errno::EFAULT);
            }
            Some(len)
        } else {
            // Invalid/short read, assume invalid address length.
            return Err(Errno::EINVAL);
        }
    } else {
        None
    };

    let sandbox = request.get_sandbox();
    let force_cloexec = sandbox.flags.force_cloexec();
    let force_rand_fd = sandbox.flags.force_rand_fd();
    drop(sandbox); // release read-lock.

    let mut flags = if op == 0x12 {
        // accept4
        SockFlag::from_bits(args[3].try_into().or(Err(Errno::EINVAL))?).ok_or(Errno::EINVAL)?
    } else {
        // accept
        SockFlag::empty()
    };
    let cloexec = force_cloexec || flags.contains(SockFlag::SOCK_CLOEXEC);
    flags.insert(SockFlag::SOCK_CLOEXEC);

    // Check whether we should block and ignore restarts.
    let (is_blocking, ignore_restart) = if !get_nonblock(&fd)? {
        let ignore_restart = has_recv_timeout(&fd)?;
        (true, ignore_restart)
    } else {
        (false, false)
    };

    // Do the accept call.
    let (fd, addr, addrlen_out) = do_accept4(fd, request, flags, is_blocking, ignore_restart)?;

    // Check the source address against the IP blocklist.
    // No port filtering is done here for simplicity and efficiency.
    let sandbox = request.get_sandbox();
    sandbox_addr(request, &sandbox, &addr, &None, op, Capability::empty())?;
    drop(sandbox); // release the read lock.

    // Write address buffer as necessary.
    if let Some(addrlen) = addrlen {
        // Create a byte slice from the socket address pointer.
        // SAFETY:
        // 1. `addrlen_out` value is returned by the host Linux kernel
        //    and is therefore trusted.
        // 2. `ptr` is a valid pointer to memory of at least
        //    `addrlen_out` bytes, as it is provided by the
        //    `SockaddrStorage` instance.
        // 3. The `SockaddrStorage` type ensures that the memory pointed
        //    to by `ptr` is valid and properly aligned.
        let buf = unsafe { std::slice::from_raw_parts(addr.as_ptr().cast(), addrlen_out as usize) };

        // Write the truncated socket address into memory.
        // SAFETY: We truncate late to avoid potential UB in
        // std::slice::slice_from_raw_parts().
        let len = addrlen_out.min(addrlen) as usize;
        request.write_mem(&buf[..len], args[1])?;

        // Convert `addrlen_out` into a vector of bytes.
        // SAFETY: This must be socklen_t and _not_ usize!
        let buf = addrlen_out.to_ne_bytes();

        // Write `addrlen_out` into memory.
        request.write_mem(&buf, args[2])?;
    }

    // Send the fd and return.
    request.send_fd(fd, cloexec, force_rand_fd)
}

/// Perform accept4(2) on `fd` with `flags` on behalf of the sandbox process.
///
/// When `is_blocking` is set, the call is first recorded with
/// `request.cache.add_sys_block` (passing `ignore_restart` through) so it
/// can get invalidated. The record is removed with `del_sys_block`
/// afterwards; the second argument tells it whether the call was
/// interrupted with `EINTR`.
///
/// Returns the accepted fd, the peer address and the address length
/// reported by the kernel.
///
/// # Errors
///
/// - Errors from the blocking-call bookkeeping and from accept4(2) are
///   propagated.
/// - `EINVAL` if the kernel-returned address cannot be converted into a
///   `SockaddrStorage`.
fn do_accept4(
    fd: OwnedFd,
    request: &UNotifyEventRequest,
    flags: SockFlag,
    is_blocking: bool,
    ignore_restart: bool,
) -> Result<(OwnedFd, SockaddrStorage, libc::socklen_t), Errno> {
    // Allocate zeroed, properly aligned storage for the address.
    // A plain byte array only has alignment 1, which does not satisfy
    // the alignment of `struct sockaddr` the pointer is cast to.
    let mut storage = std::mem::MaybeUninit::<libc::sockaddr_storage>::zeroed();
    #[allow(clippy::cast_possible_truncation)]
    let mut len = std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;

    // Pointer to the storage as a generic socket address.
    let ptr = storage.as_mut_ptr().cast::<libc::sockaddr>();

    // SAFETY: Record blocking call so it can get invalidated.
    if is_blocking {
        request
            .cache
            .add_sys_block(request.scmpreq, ignore_restart)?;
    }

    // Make the accept4(2) call.
    let result = safe_accept4(fd, ptr, &raw mut len, flags);

    // Remove invalidation record unless interrupted.
    if is_blocking {
        request
            .cache
            .del_sys_block(request.scmpreq.id, matches!(result, Err(Errno::EINTR)))?;
    }

    // Check for accept4 errors after invalidation.
    let fd = result?;

    // SAFETY:
    // Convert the raw address into a SockaddrStorage structure.
    // accept4 returned success, `ptr` points to zero-initialized,
    // suitably aligned sockaddr_storage that is still alive, and
    // `len` was filled in by the kernel.
    let addr = unsafe { SockaddrStorage::from_raw(ptr, Some(len)) }.ok_or(Errno::EINVAL)?;

    Ok((fd, addr, len))
}
