use super::Error;
use super::Result;
use super::Timeout;
use crate::auth;
use crate::message_builder::MarshalledMessage;
use crate::wire::marshal;
use crate::wire::unmarshal;
use std::time;
use std::os::unix::io::RawFd;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net::UnixStream;
use nix::cmsg_space;
use nix::sys::socket::{
self, connect, recvmsg, sendmsg, socket, ControlMessage, ControlMessageOwned, MsgFlags,
SockAddr, UnixAddr,
};
use nix::sys::uio::IoVec;
/// Sending half of a connection.
///
/// Holds the stream that messages are written to, a scratch buffer that
/// `send_message` clears and refills with the marshalled header of the
/// message being sent, and the counter that hands out message serials.
#[derive(Debug)]
pub struct SendConn {
// Stream the marshalled messages are written to.
stream: UnixStream,
// Marshalled header of the message currently being sent; reused between messages.
header_buf: Vec<u8>,
// Next serial to hand out for messages that do not carry their own.
serial_counter: u32,
}
/// Receiving half of a connection.
///
/// Accumulates the bytes and control messages of the message that is
/// currently being received until `get_next_message` can unmarshal it.
pub struct RecvConn {
// Stream the incoming bytes are read from.
stream: UnixStream,
// Bytes received so far for the message currently being assembled.
msg_buf_in: Vec<u8>,
// Control messages received together with those bytes.
cmsgs_in: Vec<ControlMessageOwned>,
}
/// A connection consisting of a sending and a receiving half.
///
/// `connect_to_bus` builds both halves from the same socket (the sending half
/// gets a `try_clone` of the stream).
pub struct DuplexConn {
// Half used to write messages.
pub send: SendConn,
// Half used to read messages.
pub recv: RecvConn,
}
impl RecvConn {
pub fn can_read_from_source(&self) -> nix::Result<bool> {
let mut fdset = nix::sys::select::FdSet::new();
let fd = self.stream.as_raw_fd();
fdset.insert(fd);
use nix::sys::time::TimeValLike;
let mut zero_timeout = nix::sys::time::TimeVal::microseconds(0);
nix::sys::select::select(None, Some(&mut fdset), None, None, Some(&mut zero_timeout))?;
Ok(fdset.contains(fd))
}
fn refill_buffer(&mut self, max_buffer_size: usize, timeout: Timeout) -> Result<()> {
let bytes_to_read = max_buffer_size - self.msg_buf_in.len();
const BUFSIZE: usize = 512;
let mut tmpbuf = [0u8; BUFSIZE];
let iovec = IoVec::from_mut_slice(&mut tmpbuf[..usize::min(bytes_to_read, BUFSIZE)]);
let mut cmsgspace = cmsg_space!([RawFd; 10]);
let flags = MsgFlags::empty();
let old_timeout = self.stream.read_timeout()?;
match timeout {
Timeout::Duration(d) => {
self.stream.set_read_timeout(Some(d))?;
}
Timeout::Infinite => {
self.stream.set_read_timeout(None)?;
}
Timeout::Nonblock => {
self.stream.set_nonblocking(true)?;
}
}
let msg = recvmsg(
self.stream.as_raw_fd(),
&[iovec],
Some(&mut cmsgspace),
flags,
)
.map_err(|e| match e.as_errno() {
Some(nix::errno::Errno::EAGAIN) => Error::TimedOut,
_ => Error::NixError(e),
});
self.stream.set_nonblocking(false)?;
self.stream.set_read_timeout(old_timeout)?;
let msg = msg?;
self.msg_buf_in
.extend(&mut tmpbuf[..msg.bytes].iter().copied());
self.cmsgs_in.extend(msg.cmsgs());
Ok(())
}
pub fn bytes_needed_for_current_message(&self) -> Result<usize> {
if self.msg_buf_in.len() < 16 {
return Ok(16);
}
let (_, header) = unmarshal::unmarshal_header(&self.msg_buf_in, 0)?;
let (_, header_fields_len) = crate::wire::util::parse_u32(
&self.msg_buf_in[unmarshal::HEADER_LEN..],
header.byteorder,
)?;
let complete_header_size = unmarshal::HEADER_LEN + header_fields_len as usize + 4;
let padding_between_header_and_body = 8 - ((complete_header_size) % 8);
let padding_between_header_and_body = if padding_between_header_and_body == 8 {
0
} else {
padding_between_header_and_body
};
let bytes_needed = complete_header_size as usize
+ padding_between_header_and_body
+ header.body_len as usize;
Ok(bytes_needed)
}
pub fn buffer_contains_whole_message(&self) -> Result<bool> {
if self.msg_buf_in.len() < 16 {
return Ok(false);
}
let bytes_needed = self.bytes_needed_for_current_message();
match bytes_needed {
Err(e) => {
if let Error::UnmarshalError(unmarshal::Error::NotEnoughBytes) = e {
Ok(false)
} else {
Err(e)
}
}
Ok(bytes_needed) => Ok(self.msg_buf_in.len() >= bytes_needed),
}
}
pub fn read_whole_message(&mut self, timeout: Timeout) -> Result<()> {
let start_time = time::Instant::now();
while !self.buffer_contains_whole_message()? {
self.refill_buffer(
self.bytes_needed_for_current_message()?,
super::calc_timeout_left(&start_time, timeout)?,
)?;
}
Ok(())
}
pub fn read_once(&mut self, timeout: Timeout) -> Result<()> {
self.refill_buffer(self.bytes_needed_for_current_message()?, timeout)?;
Ok(())
}
pub fn get_next_message(&mut self, timeout: Timeout) -> Result<MarshalledMessage> {
self.read_whole_message(timeout)?;
let (hdrbytes, header) = unmarshal::unmarshal_header(&self.msg_buf_in, 0)?;
let (dynhdrbytes, dynheader) =
unmarshal::unmarshal_dynamic_header(&header, &self.msg_buf_in, hdrbytes)?;
let (bytes_used, mut msg) = unmarshal::unmarshal_next_message(
&header,
dynheader,
&self.msg_buf_in,
hdrbytes + dynhdrbytes,
)?;
if self.msg_buf_in.len() != bytes_used + hdrbytes + dynhdrbytes {
return Err(Error::UnmarshalError(unmarshal::Error::NotAllBytesUsed));
}
self.msg_buf_in.clear();
for cmsg in &self.cmsgs_in {
match cmsg {
ControlMessageOwned::ScmRights(fds) => {
msg.body
.raw_fds
.extend(fds.iter().map(|fd| crate::wire::UnixFd::new(*fd)));
}
_ => {
eprintln!("Cmsg other than ScmRights: {:?}", cmsg);
}
}
}
self.cmsgs_in.clear();
Ok(msg)
}
}
impl SendConn {
    /// Reserves and returns the next message serial.
    ///
    /// The counter advances by one per call. When it would overflow it
    /// restarts at 1 instead of wrapping to 0, because 0 is not a valid
    /// D-Bus message serial (and the plain `+= 1` would panic in debug builds).
    pub fn alloc_serial(&mut self) -> u32 {
        let serial = self.serial_counter;
        self.serial_counter = serial.checked_add(1).unwrap_or(1);
        serial
    }

    /// Prepares `msg` for sending and returns a context that drives the
    /// actual writes.
    ///
    /// The serial already set in the message's dynamic header is used if
    /// there is one; otherwise a fresh one is taken from `alloc_serial`.
    /// The header is marshalled into the internal header buffer. Nothing is
    /// written to the socket by this function itself.
    ///
    /// # Errors
    /// Propagates marshalling errors.
    pub fn send_message<'a>(
        &'a mut self,
        msg: &'a MarshalledMessage,
    ) -> Result<SendMessageContext<'a>> {
        let serial = match msg.dynheader.serial {
            Some(serial) => serial,
            None => self.alloc_serial(),
        };

        // The header buffer is reused between messages.
        self.header_buf.clear();
        marshal::marshal(msg, serial, &mut self.header_buf)?;

        Ok(SendMessageContext {
            msg,
            conn: self,
            state: SendMessageState {
                bytes_sent: 0,
                serial,
            },
        })
    }

    /// Sends `msg` completely, blocking without a timeout, and returns the
    /// serial that was used.
    ///
    /// # Errors
    /// Returns marshalling or write errors. On a write error the send context
    /// is force-finished, so a partially written message does not trigger the
    /// context's drop panic.
    pub fn send_message_write_all(&mut self, msg: &MarshalledMessage) -> Result<u32> {
        let ctx = self.send_message(msg)?;
        ctx.write_all().map_err(force_finish_on_error)
    }
}
/// Error adapter for the `(context, error)` pairs returned by the writing
/// functions: force-finishes the context (so dropping it cannot panic) and
/// hands back only the error.
pub fn force_finish_on_error<E>(failure: (SendMessageContext<'_>, E)) -> E {
    let (ctx, err) = failure;
    ctx.force_finish();
    err
}
/// In-progress send of a single message over a `SendConn`.
///
/// Created by `SendConn::send_message` or `SendMessageContext::resume`.
/// Dropping it after only part of the message has been written panics (see
/// the `Drop` impl); use one of the consuming functions instead.
#[must_use = "Dropping this type is considered an error since it might leave the connection in an illdefined state if only some bytes of a message have been written"]
#[derive(Debug)]
pub struct SendMessageContext<'a> {
// Message being sent; its buffer provides the body bytes.
msg: &'a MarshalledMessage,
// Connection to write to; its header buffer holds the marshalled header.
conn: &'a mut SendConn,
// Progress made so far.
state: SendMessageState,
}
/// Progress of a message send; can be extracted with
/// `SendMessageContext::into_progress` and passed back to
/// `SendMessageContext::resume`.
#[derive(Debug, Copy, Clone)]
pub struct SendMessageState {
// Number of bytes written so far, counted over header and body together.
bytes_sent: usize,
// Serial the message is being sent with.
serial: u32,
}
impl Drop for SendMessageContext<'_> {
    /// Panics if the context is dropped after some, but not all, bytes of the
    /// message have been written.
    ///
    /// Dropping before anything was sent, or after everything was sent, is
    /// fine. The check is skipped while the thread is already unwinding:
    /// a second panic raised from a destructor would abort the process and
    /// hide the original panic.
    fn drop(&mut self) {
        let partially_sent = self.state.bytes_sent != 0 && !self.all_bytes_written();
        if partially_sent && !std::thread::panicking() {
            panic!("You dropped a SendMessageContext that only partially sent the message! This is not ok since that leaves the connection in an ill defined state. Use one of the consuming functions!");
        }
    }
}
impl SendMessageContext<'_> {
    /// Serial the message is being sent with.
    pub fn serial(&self) -> u32 {
        self.state.serial
    }

    /// Rebuilds a context from progress previously obtained through
    /// `into_progress`, so sending can continue where it stopped.
    pub fn resume<'a>(
        conn: &'a mut SendConn,
        msg: &'a MarshalledMessage,
        progress: SendMessageState,
    ) -> SendMessageContext<'a> {
        SendMessageContext {
            msg,
            conn,
            state: progress,
        }
    }

    /// Consumes the context without running its drop check and returns the
    /// progress made so far.
    pub fn into_progress(self) -> SendMessageState {
        let snapshot = self.state;
        self.force_finish();
        snapshot
    }

    /// Drops the context when `res` is `Ok`; otherwise returns the context
    /// together with the error so the caller decides what to do with it.
    fn finish_if_ok<O, E>(
        self,
        res: std::result::Result<O, E>,
    ) -> std::result::Result<O, (Self, E)> {
        match res {
            Err(err) => Err((self, err)),
            Ok(value) => {
                drop(self);
                Ok(value)
            }
        }
    }

    /// Consumes the context without running its drop check, regardless of
    /// how much of the message has been written.
    pub fn force_finish(self) {
        std::mem::forget(self)
    }

    /// Writes until the whole message is sent or an error occurs.
    ///
    /// `timeout` covers the whole call: every individual write gets what
    /// `calc_timeout_left` reports as remaining. Returns the message serial on
    /// success; on failure returns the context (with its progress) and the
    /// error.
    pub fn write(mut self, timeout: Timeout) -> std::result::Result<u32, (Self, super::Error)> {
        let started = std::time::Instant::now();
        let outcome = loop {
            let remaining = match super::calc_timeout_left(&started, timeout) {
                Ok(remaining) => remaining,
                Err(err) => break Err(err),
            };
            if let Err(err) = self.write_once(remaining) {
                break Err(err);
            }
            if self.all_bytes_written() {
                break Ok(self.state.serial);
            }
        };
        self.finish_if_ok(outcome)
    }

    /// Same as `write` with an infinite timeout.
    pub fn write_all(self) -> std::result::Result<u32, (Self, super::Error)> {
        self.write(Timeout::Infinite)
    }

    /// Total number of bytes of the message: marshalled header plus body buffer.
    pub fn bytes_total(&self) -> usize {
        let header_len = self.conn.header_buf.len();
        let body_len = self.msg.get_buf().len();
        header_len + body_len
    }

    /// Whether the number of bytes sent equals `bytes_total`.
    pub fn all_bytes_written(&self) -> bool {
        self.bytes_total() == self.state.bytes_sent
    }

    /// Performs a single `sendmsg` with whatever part of the header and body
    /// has not been sent yet, and returns the number of bytes written.
    ///
    /// The file descriptors of the message body are attached only on the
    /// first write (while nothing has been sent yet). The requested timeout
    /// is applied just for this call and the previous socket settings are
    /// restored afterwards.
    ///
    /// NOTE(review): unlike the receive path, `EAGAIN` from `sendmsg` is not
    /// mapped to `Error::TimedOut` here; it surfaces through the generic
    /// error conversion.
    pub fn write_once(&mut self, timeout: Timeout) -> Result<usize> {
        // `bytes_sent` counts over header and body; split it into one offset each.
        let already_sent = self.state.bytes_sent;
        let header = &self.conn.header_buf;
        let header_offset = usize::min(already_sent, header.len());
        let body_offset = already_sent - header_offset;
        let iov = [
            IoVec::from_slice(&header[header_offset..]),
            IoVec::from_slice(&self.msg.get_buf()[body_offset..]),
        ];
        let flags = MsgFlags::empty();

        let saved_timeout = self.conn.stream.write_timeout()?;
        match timeout {
            Timeout::Nonblock => self.conn.stream.set_nonblocking(true)?,
            Timeout::Infinite => self.conn.stream.set_write_timeout(None)?,
            Timeout::Duration(limit) => self.conn.stream.set_write_timeout(Some(limit))?,
        }

        let raw_fds: Vec<RawFd> = if already_sent == 0 {
            self.msg
                .body
                .raw_fds
                .iter()
                .filter_map(|fd| fd.get_raw_fd())
                .collect()
        } else {
            Vec::new()
        };

        let send_result = sendmsg(
            self.conn.stream.as_raw_fd(),
            &iov,
            &[ControlMessage::ScmRights(&raw_fds)],
            flags,
            None,
        );

        // Restore the socket settings before evaluating the result so they
        // are reset on the error path too.
        self.conn.stream.set_write_timeout(saved_timeout)?;
        self.conn.stream.set_nonblocking(false)?;

        let written = send_result?;
        self.state.bytes_sent += written;
        Ok(written)
    }
}
impl DuplexConn {
pub fn connect_to_bus(addr: UnixAddr, with_unix_fd: bool) -> super::Result<DuplexConn> {
let sock = socket(
socket::AddressFamily::Unix,
socket::SockType::Stream,
socket::SockFlag::empty(),
None,
)?;
let sock_addr = SockAddr::Unix(addr);
connect(sock, &sock_addr)?;
let mut stream = unsafe { UnixStream::from_raw_fd(sock) };
match auth::do_auth(&mut stream)? {
auth::AuthResult::Ok => {}
auth::AuthResult::Rejected => return Err(Error::AuthFailed),
}
if with_unix_fd {
match auth::negotiate_unix_fds(&mut stream)? {
auth::AuthResult::Ok => {}
auth::AuthResult::Rejected => return Err(Error::UnixFdNegotiationFailed),
}
}
auth::send_begin(&mut stream)?;
Ok(DuplexConn {
send: SendConn {
stream: stream.try_clone()?,
header_buf: Vec::new(),
serial_counter: 1,
},
recv: RecvConn {
msg_buf_in: Vec::new(),
cmsgs_in: Vec::new(),
stream,
},
})
}
pub fn send_hello(&mut self, timeout: crate::connection::Timeout) -> super::Result<String> {
let start_time = time::Instant::now();
let hello = crate::standard_messages::hello();
let serial = self
.send
.send_message(&hello)?
.write(super::calc_timeout_left(&start_time, timeout)?)
.map_err(|(ctx, e)| {
ctx.force_finish();
e
})?;
let resp = self
.recv
.get_next_message(super::calc_timeout_left(&start_time, timeout)?)?;
if resp.dynheader.response_serial != Some(serial) {
return Err(super::Error::AuthFailed);
}
let unique_name = resp.body.parser().get::<String>()?;
Ok(unique_name)
}
}
impl AsRawFd for SendConn {
    /// Raw file descriptor of the stream this half writes to.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.stream)
    }
}
impl AsRawFd for RecvConn {
    /// Raw file descriptor of the stream this half reads from.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.stream)
    }
}
impl AsRawFd for DuplexConn {
    /// Raw file descriptor of the receiving half's stream.
    fn as_raw_fd(&self) -> RawFd {
        let receiving_stream = &self.recv.stream;
        receiving_stream.as_raw_fd()
    }
}