use super::Error;
use super::Result;
use super::Timeout;
use crate::auth;
use crate::message_builder::MarshalledMessage;
use crate::wire::marshal;
use crate::wire::unmarshal;
use std::time;
use std::os::unix::io::RawFd;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net::UnixStream;
use nix::cmsg_space;
use nix::sys::socket::{
    self, connect, recvmsg, sendmsg, socket, ControlMessage, ControlMessageOwned, MsgFlags,
    SockAddr, UnixAddr,
};
use nix::sys::uio::IoVec;
/// The sending half of a D-Bus connection.
///
/// Holds its own handle to the socket (`DuplexConn::connect_to_bus` creates it with
/// `try_clone`), a scratch buffer for the marshalled header of the message that is
/// currently being sent, and the counter used to hand out message serials.
#[derive(Debug)]
pub struct SendConn {
    stream: UnixStream,
    // Reused for every outgoing message; `send_message` clears and refills it, and
    // `SendMessageContext::write_once` sends from it.
    header_buf: Vec<u8>,
    // Next serial handed out by `alloc_serial` / `send_message` (starts at 1 in `connect_to_bus`).
    serial_counter: u32,
}
/// The receiving half of a D-Bus connection.
///
/// Incoming bytes are accumulated in `msg_buf_in` until one complete message is
/// buffered. Control messages received alongside those bytes are kept in `cmsgs_in`
/// until `get_next_message` attaches them to the message it returns.
pub struct RecvConn {
    stream: UnixStream,
    // Bytes of the current (possibly incomplete) incoming message.
    msg_buf_in: Vec<u8>,
    // Ancillary data (e.g. SCM_RIGHTS file descriptors) received with those bytes.
    cmsgs_in: Vec<ControlMessageOwned>,
}
/// A bus connection split into independently usable sending and receiving halves.
///
/// Both halves refer to the same socket: `connect_to_bus` builds `send` from a
/// `try_clone` of the stream that `recv` owns.
pub struct DuplexConn {
    pub send: SendConn,
    pub recv: RecvConn,
}
impl RecvConn {
    pub fn can_read_from_source(&self) -> nix::Result<bool> {
        let mut fdset = nix::sys::select::FdSet::new();
        let fd = self.stream.as_raw_fd();
        fdset.insert(fd);
        use nix::sys::time::TimeValLike;
        let mut zero_timeout = nix::sys::time::TimeVal::microseconds(0);
        nix::sys::select::select(None, Some(&mut fdset), None, None, Some(&mut zero_timeout))?;
        Ok(fdset.contains(fd))
    }
    
    
    fn refill_buffer(&mut self, max_buffer_size: usize, timeout: Timeout) -> Result<()> {
        let bytes_to_read = max_buffer_size - self.msg_buf_in.len();
        const BUFSIZE: usize = 512;
        let mut tmpbuf = [0u8; BUFSIZE];
        let iovec = IoVec::from_mut_slice(&mut tmpbuf[..usize::min(bytes_to_read, BUFSIZE)]);
        let mut cmsgspace = cmsg_space!([RawFd; 10]);
        let flags = MsgFlags::empty();
        let old_timeout = self.stream.read_timeout()?;
        match timeout {
            Timeout::Duration(d) => {
                self.stream.set_read_timeout(Some(d))?;
            }
            Timeout::Infinite => {
                self.stream.set_read_timeout(None)?;
            }
            Timeout::Nonblock => {
                self.stream.set_nonblocking(true)?;
            }
        }
        let msg = recvmsg(
            self.stream.as_raw_fd(),
            &[iovec],
            Some(&mut cmsgspace),
            flags,
        )
        .map_err(|e| match e.as_errno() {
            Some(nix::errno::Errno::EAGAIN) => Error::TimedOut,
            _ => Error::NixError(e),
        });
        self.stream.set_nonblocking(false)?;
        self.stream.set_read_timeout(old_timeout)?;
        let msg = msg?;
        self.msg_buf_in
            .extend(&mut tmpbuf[..msg.bytes].iter().copied());
        self.cmsgs_in.extend(msg.cmsgs());
        Ok(())
    }
    pub fn bytes_needed_for_current_message(&self) -> Result<usize> {
        if self.msg_buf_in.len() < 16 {
            return Ok(16);
        }
        let (_, header) = unmarshal::unmarshal_header(&self.msg_buf_in, 0)?;
        let (_, header_fields_len) = crate::wire::util::parse_u32(
            &self.msg_buf_in[unmarshal::HEADER_LEN..],
            header.byteorder,
        )?;
        let complete_header_size = unmarshal::HEADER_LEN + header_fields_len as usize + 4; 
        let padding_between_header_and_body = 8 - ((complete_header_size) % 8);
        let padding_between_header_and_body = if padding_between_header_and_body == 8 {
            0
        } else {
            padding_between_header_and_body
        };
        let bytes_needed = complete_header_size as usize
            + padding_between_header_and_body
            + header.body_len as usize;
        Ok(bytes_needed)
    }
    
    pub fn buffer_contains_whole_message(&self) -> Result<bool> {
        if self.msg_buf_in.len() < 16 {
            return Ok(false);
        }
        let bytes_needed = self.bytes_needed_for_current_message();
        match bytes_needed {
            Err(e) => {
                if let Error::UnmarshalError(unmarshal::Error::NotEnoughBytes) = e {
                    Ok(false)
                } else {
                    Err(e)
                }
            }
            Ok(bytes_needed) => Ok(self.msg_buf_in.len() >= bytes_needed),
        }
    }
    
    pub fn read_whole_message(&mut self, timeout: Timeout) -> Result<()> {
        
        
        
        let start_time = time::Instant::now();
        while !self.buffer_contains_whole_message()? {
            self.refill_buffer(
                self.bytes_needed_for_current_message()?,
                super::calc_timeout_left(&start_time, timeout)?,
            )?;
        }
        Ok(())
    }
    
    pub fn read_once(&mut self, timeout: Timeout) -> Result<()> {
        self.refill_buffer(self.bytes_needed_for_current_message()?, timeout)?;
        Ok(())
    }
    
    pub fn get_next_message(&mut self, timeout: Timeout) -> Result<MarshalledMessage> {
        self.read_whole_message(timeout)?;
        let (hdrbytes, header) = unmarshal::unmarshal_header(&self.msg_buf_in, 0)?;
        let (dynhdrbytes, dynheader) =
            unmarshal::unmarshal_dynamic_header(&header, &self.msg_buf_in, hdrbytes)?;
        let (bytes_used, mut msg) = unmarshal::unmarshal_next_message(
            &header,
            dynheader,
            &self.msg_buf_in,
            hdrbytes + dynhdrbytes,
        )?;
        if self.msg_buf_in.len() != bytes_used + hdrbytes + dynhdrbytes {
            return Err(Error::UnmarshalError(unmarshal::Error::NotAllBytesUsed));
        }
        self.msg_buf_in.clear();
        for cmsg in &self.cmsgs_in {
            match cmsg {
                ControlMessageOwned::ScmRights(fds) => {
                    msg.body
                        .raw_fds
                        .extend(fds.iter().map(|fd| crate::wire::UnixFd::new(*fd)));
                }
                _ => {
                    
                    eprintln!("Cmsg other than ScmRights: {:?}", cmsg);
                }
            }
        }
        self.cmsgs_in.clear();
        Ok(msg)
    }
}
impl SendConn {
    /// Returns a fresh serial for an outgoing message and advances the counter.
    ///
    /// The counter wraps around on overflow and skips 0, because D-Bus does not allow
    /// a message serial of zero.
    pub fn alloc_serial(&mut self) -> u32 {
        let serial = self.serial_counter;
        // `max(1)` only changes the value when the increment wrapped to 0.
        self.serial_counter = serial.wrapping_add(1).max(1);
        serial
    }

    /// Prepares `msg` for sending on this connection and returns the context that
    /// performs the actual writes.
    ///
    /// If `msg` already carries a serial, that serial is used; otherwise a new one is
    /// taken from `alloc_serial`. The header is marshalled into this connection's
    /// header buffer. Nothing is written to the socket until one of the returned
    /// context's write methods is called.
    ///
    /// # Errors
    /// Propagates marshalling errors.
    pub fn send_message<'a>(
        &'a mut self,
        msg: &'a MarshalledMessage,
    ) -> Result<SendMessageContext<'a>> {
        let serial = match msg.dynheader.serial {
            Some(serial) => serial,
            None => self.alloc_serial(),
        };

        self.header_buf.clear();
        marshal::marshal(msg, serial, &mut self.header_buf)?;
        Ok(SendMessageContext {
            msg,
            conn: self,
            state: SendMessageState {
                bytes_sent: 0,
                serial,
            },
        })
    }

    /// Sends `msg` completely, without a timeout, and returns its serial.
    ///
    /// If writing fails, the partially sent context is discarded with
    /// `force_finish_on_error`, and the connection may be left in the middle of a message.
    pub fn send_message_write_all(&mut self, msg: &MarshalledMessage) -> Result<u32> {
        let ctx = self.send_message(msg)?;
        ctx.write_all().map_err(force_finish_on_error)
    }
}
/// Adapter for `map_err` on the results of the consuming write methods
/// (`SendMessageContext::write` / `write_all`).
///
/// Discards the returned context with `force_finish`, which skips its partial-send check
/// on drop, and keeps only the error.
pub fn force_finish_on_error<E>(failure: (SendMessageContext<'_>, E)) -> E {
    let (ctx, err) = failure;
    ctx.force_finish();
    err
}
/// An in-progress send of one message over a `SendConn`.
///
/// Tracks how many bytes of the marshalled header (kept in the connection's
/// `header_buf`) and of the message body have been written so far. Dispose of it
/// through one of its consuming methods (`write`, `write_all`, `into_progress`,
/// `force_finish`). Dropping it after only part of the message was written panics
/// (see its `Drop` impl).
#[must_use = "Dropping this type is considered an error since it might leave the connection in an illdefined state if only some bytes of a message have been written"]
#[derive(Debug)]
pub struct SendMessageContext<'a> {
    msg: &'a MarshalledMessage,
    // Exclusive borrow: no other message can be started on this connection meanwhile.
    conn: &'a mut SendConn,
    state: SendMessageState,
}
/// Saved progress of a message send.
///
/// Obtained from `SendMessageContext::into_progress` and turned back into a context
/// with `SendMessageContext::resume`.
#[derive(Debug, Copy, Clone)]
pub struct SendMessageState {
    // Bytes written so far, counted across header and body.
    bytes_sent: usize,
    // Serial assigned to the message being sent.
    serial: u32,
}
impl Drop for SendMessageContext<'_> {
    /// Panics if the message was only partially written. A context that wrote nothing
    /// yet, or the whole message, may be dropped freely.
    ///
    /// The check is skipped while the thread is already unwinding from a panic. A
    /// second panic there would abort the process and hide the original panic.
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }
        if self.state.bytes_sent != 0 && !self.all_bytes_written() {
            panic!("You dropped a SendMessageContext that only partially sent the message! This is not ok since that leaves the connection in an ill defined state. Use one of the consuming functions!");
        }
    }
}
impl SendMessageContext<'_> {
    pub fn serial(&self) -> u32 {
        self.state.serial
    }
    
    
    pub fn resume<'a>(
        conn: &'a mut SendConn,
        msg: &'a MarshalledMessage,
        progress: SendMessageState,
    ) -> SendMessageContext<'a> {
        SendMessageContext {
            conn,
            msg,
            state: progress,
        }
    }
    
    
    
    pub fn into_progress(self) -> SendMessageState {
        let progress = self.state;
        Self::force_finish(self);
        progress
    }
    
    fn finish_if_ok<O, E>(
        self,
        res: std::result::Result<O, E>,
    ) -> std::result::Result<O, (Self, E)> {
        match res {
            Ok(o) => {
                
                std::mem::drop(self);
                Ok(o)
            }
            Err(e) => Err((self, e)),
        }
    }
    
    
    
    pub fn force_finish(self) {
        std::mem::forget(self)
    }
    
    
    pub fn write(mut self, timeout: Timeout) -> std::result::Result<u32, (Self, super::Error)> {
        let start_time = std::time::Instant::now();
        
        let res = loop {
            let iteration_timeout = super::calc_timeout_left(&start_time, timeout);
            let iteration_timeout = match iteration_timeout {
                Err(e) => break Err(e),
                Ok(t) => t,
            };
            match self.write_once(iteration_timeout) {
                Err(e) => break Err(e),
                Ok(t) => t,
            };
            if self.all_bytes_written() {
                break Ok(self.state.serial);
            }
        };
        
        self.finish_if_ok(res)
    }
    
    pub fn write_all(self) -> std::result::Result<u32, (Self, super::Error)> {
        self.write(Timeout::Infinite)
    }
    
    pub fn bytes_total(&self) -> usize {
        self.conn.header_buf.len() + self.msg.get_buf().len()
    }
    
    pub fn all_bytes_written(&self) -> bool {
        self.state.bytes_sent == self.bytes_total()
    }
    
    
    pub fn write_once(&mut self, timeout: Timeout) -> Result<usize> {
        
        
        let header_bytes_sent = usize::min(self.state.bytes_sent, self.conn.header_buf.len());
        let header_slice_to_send = &self.conn.header_buf[header_bytes_sent..];
        let body_bytes_sent = self.state.bytes_sent - header_bytes_sent;
        let body_slice_to_send = &self.msg.get_buf()[body_bytes_sent..];
        let iov = [
            IoVec::from_slice(header_slice_to_send),
            IoVec::from_slice(body_slice_to_send),
        ];
        let flags = MsgFlags::empty();
        let old_timeout = self.conn.stream.write_timeout()?;
        match timeout {
            Timeout::Duration(d) => {
                self.conn.stream.set_write_timeout(Some(d))?;
            }
            Timeout::Infinite => {
                self.conn.stream.set_write_timeout(None)?;
            }
            Timeout::Nonblock => {
                self.conn.stream.set_nonblocking(true)?;
            }
        }
        
        
        let raw_fds = if self.state.bytes_sent == 0 {
            self.msg
                .body
                .raw_fds
                .iter()
                .map(|fd| fd.get_raw_fd())
                .flatten()
                .collect::<Vec<RawFd>>()
        } else {
            vec![]
        };
        let bytes_sent = sendmsg(
            self.conn.stream.as_raw_fd(),
            &iov,
            &[ControlMessage::ScmRights(&raw_fds)],
            flags,
            None,
        );
        self.conn.stream.set_write_timeout(old_timeout)?;
        self.conn.stream.set_nonblocking(false)?;
        let bytes_sent = bytes_sent?;
        self.state.bytes_sent += bytes_sent;
        Ok(bytes_sent)
    }
}
impl DuplexConn {
    
    pub fn connect_to_bus(addr: UnixAddr, with_unix_fd: bool) -> super::Result<DuplexConn> {
        let sock = socket(
            socket::AddressFamily::Unix,
            socket::SockType::Stream,
            socket::SockFlag::empty(),
            None,
        )?;
        let sock_addr = SockAddr::Unix(addr);
        connect(sock, &sock_addr)?;
        let mut stream = unsafe { UnixStream::from_raw_fd(sock) };
        match auth::do_auth(&mut stream)? {
            auth::AuthResult::Ok => {}
            auth::AuthResult::Rejected => return Err(Error::AuthFailed),
        }
        if with_unix_fd {
            match auth::negotiate_unix_fds(&mut stream)? {
                auth::AuthResult::Ok => {}
                auth::AuthResult::Rejected => return Err(Error::UnixFdNegotiationFailed),
            }
        }
        auth::send_begin(&mut stream)?;
        Ok(DuplexConn {
            send: SendConn {
                stream: stream.try_clone()?,
                header_buf: Vec::new(),
                serial_counter: 1,
            },
            recv: RecvConn {
                msg_buf_in: Vec::new(),
                cmsgs_in: Vec::new(),
                stream,
            },
        })
    }
    
    pub fn send_hello(&mut self, timeout: crate::connection::Timeout) -> super::Result<String> {
        let start_time = time::Instant::now();
        let hello = crate::standard_messages::hello();
        let serial = self
            .send
            .send_message(&hello)?
            .write(super::calc_timeout_left(&start_time, timeout)?)
            .map_err(|(ctx, e)| {
                ctx.force_finish();
                e
            })?;
        let resp = self
            .recv
            .get_next_message(super::calc_timeout_left(&start_time, timeout)?)?;
        if resp.dynheader.response_serial != Some(serial) {
            return Err(super::Error::AuthFailed);
        }
        let unique_name = resp.body.parser().get::<String>()?;
        Ok(unique_name)
    }
}
impl AsRawFd for SendConn {
    
    
    fn as_raw_fd(&self) -> RawFd {
        self.stream.as_raw_fd()
    }
}
impl AsRawFd for RecvConn {
    
    
    fn as_raw_fd(&self) -> RawFd {
        self.stream.as_raw_fd()
    }
}
impl AsRawFd for DuplexConn {
    
    
    fn as_raw_fd(&self) -> RawFd {
        self.recv.stream.as_raw_fd()
    }
}