diff --git a/litebox/src/net/socket_channel.rs b/litebox/src/net/socket_channel.rs index 9505203ed..766ab9f8b 100644 --- a/litebox/src/net/socket_channel.rs +++ b/litebox/src/net/socket_channel.rs @@ -59,9 +59,6 @@ use crate::sync::{Mutex, RawSyncPrimitivesProvider}; use crate::{ event::{Events, IOPollable, observer::Observer, polling::Pollee}, net::ReceiveFlags, -}; -use crate::{ - net::errors::{ReceiveError, SendError}, platform::TimeProvider, }; @@ -120,8 +117,8 @@ impl SocketAsyncErrorState { #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u32)] pub enum SocketState { - /// Socket is closed or in initial state - Closed = 0, + /// Socket is in initial state. + Initial = 0, /// Socket is connecting (TCP SYN sent) Connecting = 1, /// Socket is connected and ready for data transfer @@ -130,15 +127,46 @@ pub enum SocketState { Listening = 3, /// Socket encountered an error Error = 4, + /// Socket is closed. + Closed = 5, +} + +/// Possible errors from [`NetworkProxy::try_read`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ChannelReadError { + /// The local read side has been shut down. + ReadShutdown, + /// The stream has not reached a connected state. + NotConnected, + /// The stream is closed. + ConnectionClosed, +} + +/// Possible errors from [`NetworkProxy::try_write`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ChannelWriteError { + /// The local write side has been shut down. + WriteShutdown, + /// The stream has not reached a connected state. + NotConnected, + /// The stream is closed. + ConnectionClosed, + /// The destination address cannot be used. + Unaddressable, + /// The transmit buffer is full. + BufferFull, + /// A destination address is required but was not provided. + DestinationAddressRequired, } impl From for SocketState { fn from(v: u32) -> Self { match v { - 0 => SocketState::Closed, + 0 => SocketState::Initial, 1 => SocketState::Connecting, 2 => SocketState::Connected, 3 => SocketState::Listening, + 5 => SocketState::Closed, _ => SocketState::Error, } } @@ -175,14 +203,10 @@ impl NetworkProxy } /// Set the async socket error. - pub(super) fn set_async_error(&self, error: super::errors::SocketAsyncError) { + pub fn set_async_error(&self, error: super::errors::SocketAsyncError) { match self { - NetworkProxy::Stream(channel) => { - channel.set_async_error(error); - } - NetworkProxy::Datagram(channel) => { - channel.set_async_error(error); - } + NetworkProxy::Stream(channel) => channel.set_async_error(error), + NetworkProxy::Datagram(channel) => channel.set_async_error(error), NetworkProxy::Raw => {} } } @@ -205,7 +229,7 @@ impl NetworkProxy buf: &mut [u8], flags: super::ReceiveFlags, source_addr: Option<&mut Option>, - ) -> Result { + ) -> Result { match self { NetworkProxy::Stream(channel) => channel.try_read(buf, flags, source_addr), NetworkProxy::Datagram(channel) => channel.try_read(buf, flags, source_addr), @@ -222,14 +246,14 @@ impl NetworkProxy buf: &[u8], flags: super::SendFlags, destination: Option, - ) -> Result { + ) -> Result { if !flags.is_empty() { unimplemented!() } if let Some(addr) = destination && (addr.port() == 0 || addr.ip().is_unspecified()) { - return Err(SendError::Unaddressable); + return Err(ChannelWriteError::Unaddressable); } match self { NetworkProxy::Stream(channel) => channel.try_write(buf), @@ -331,7 +355,7 @@ impl StreamChannelInner StreamSocketChannel>, - ) -> Result { + ) -> Result { if self.inner.read_shutdown.load(Ordering::Acquire) { - return Err(ReceiveError::SocketInInvalidState); - } - - match self.inner.state() { - SocketState::Connected => {} - _ => return Err(ReceiveError::SocketInInvalidState), + return Err(ChannelReadError::ReadShutdown); } let mut rx_cons = self.inner.rx_cons.lock(); @@ -418,7 +437,15 @@ impl StreamSocketChannel 0 { + return Ok(n); + } + match self.inner.state() { + SocketState::Connected => Ok(0), + SocketState::Closed | SocketState::Error => Err(ChannelReadError::ConnectionClosed), + _ => Err(ChannelReadError::NotConnected), + } } /// Write data to the socket from the provided buffer. @@ -428,14 +455,17 @@ impl StreamSocketChannel Result { + pub fn try_write(&self, buf: &[u8]) -> Result { if self.inner.write_shutdown.load(Ordering::Acquire) { - return Err(SendError::SocketInInvalidState); + return Err(ChannelWriteError::WriteShutdown); } match self.state() { SocketState::Connected => {} - _ => return Err(SendError::SocketInInvalidState), + SocketState::Closed | SocketState::Error => { + return Err(ChannelWriteError::ConnectionClosed); + } + _ => return Err(ChannelWriteError::NotConnected), } let mut tx_prod = self.inner.tx_prod.lock(); @@ -446,7 +476,7 @@ impl StreamSocketChannel IOPollable } match self.inner.state() { - SocketState::Closed => events |= Events::HUP | Events::OUT, + SocketState::Initial | SocketState::Closed => events |= Events::HUP | Events::OUT, SocketState::Error => events |= Events::ERR | Events::OUT, SocketState::Connected if self.is_writable() => events |= Events::OUT, _ => {} @@ -791,7 +821,7 @@ impl DatagramSocketChannel

>, - ) -> Result { + ) -> Result { let mut rx_cons = self.inner.rx_cons.lock(); if let Some(msg) = rx_cons.try_pop() { @@ -814,10 +844,14 @@ impl DatagramSocketChannel

) -> Result { + pub fn send_to( + &self, + data: &[u8], + addr: Option, + ) -> Result { if addr.is_none() && !self.inner.is_connected.load(Ordering::Acquire) { // No destination specified and socket is not connected - return Err(SendError::DestinationAddressRequired); + return Err(ChannelWriteError::DestinationAddressRequired); } let size = data.len(); @@ -832,7 +866,7 @@ impl DatagramSocketChannel

Err(SendError::BufferFull), + Err(_) => Err(ChannelWriteError::BufferFull), } } @@ -1006,8 +1040,7 @@ mod tests { fn stream_channel_initial_state() { let channel: StreamSocketChannel = StreamSocketChannel::new(); - // Initial state should be Closed - assert_eq!(channel.state(), SocketState::Closed); + assert_eq!(channel.state(), SocketState::Initial); // Should not be readable initially assert!(!channel.is_readable()); @@ -1078,12 +1111,12 @@ mod tests { // Try to read while not connected let mut buf = [0u8; 32]; let result = channel.try_read(&mut buf, super::super::ReceiveFlags::empty(), None); - assert!(matches!(result, Err(ReceiveError::SocketInInvalidState))); + assert!(matches!(result, Err(ChannelReadError::NotConnected))); // Try to write while not connected let data = b"test"; let result = channel.try_write(data); - assert!(matches!(result, Err(SendError::SocketInInvalidState))); + assert!(matches!(result, Err(ChannelWriteError::NotConnected))); } #[test] @@ -1105,7 +1138,31 @@ mod tests { // Should fail to read let mut buf = [0u8; 32]; let result = channel.try_read(&mut buf, super::super::ReceiveFlags::empty(), None); - assert!(matches!(result, Err(ReceiveError::SocketInInvalidState))); + assert!(matches!(result, Err(ChannelReadError::ReadShutdown))); + } + + #[test] + fn stream_channel_closed_after_connected_drains_rx_before_eof() { + let channel: StreamSocketChannel = StreamSocketChannel::new(); + channel.set_state(SocketState::Connected); + + let data = b"data"; + channel.push_rx_data_with(|buf: &mut [u8]| { + let to_copy = core::cmp::min(buf.len(), data.len()); + buf[..to_copy].copy_from_slice(&data[..to_copy]); + to_copy + }); + channel.set_state(SocketState::Closed); + + let mut buf = [0u8; 32]; + let read = channel + .try_read(&mut buf, super::super::ReceiveFlags::empty(), None) + .unwrap(); + assert_eq!(read, data.len()); + assert_eq!(&buf[..read], data); + + let result = channel.try_read(&mut buf, super::super::ReceiveFlags::empty(), None); + assert!(matches!(result, Err(ChannelReadError::ConnectionClosed))); } #[test] @@ -1118,7 +1175,7 @@ mod tests { // Should fail to write let result = channel.try_write(b"data"); - assert!(matches!(result, Err(SendError::SocketInInvalidState))); + assert!(matches!(result, Err(ChannelWriteError::WriteShutdown))); } #[test] @@ -1178,7 +1235,6 @@ mod tests { fn stream_channel_io_events() { let channel: StreamSocketChannel = StreamSocketChannel::new(); - // Closed state should have HUP let events = channel.check_io_events(); assert!(events.contains(Events::HUP)); @@ -1310,7 +1366,7 @@ mod tests { // Next send should fail let result = channel.send_to(&[99], Some(DUMMY_ADDR)); - assert!(matches!(result, Err(SendError::BufferFull))); + assert!(matches!(result, Err(ChannelWriteError::BufferFull))); } #[test] @@ -1319,7 +1375,10 @@ mod tests { // Sending without an address on an unconnected socket should fail let result = channel.send_to(&[1, 2, 3], None); - assert!(matches!(result, Err(SendError::DestinationAddressRequired))); + assert!(matches!( + result, + Err(ChannelWriteError::DestinationAddressRequired) + )); } #[test] diff --git a/litebox_common_linux/src/errno/mod.rs b/litebox_common_linux/src/errno/mod.rs index 0aa23aae2..044959b6b 100644 --- a/litebox_common_linux/src/errno/mod.rs +++ b/litebox_common_linux/src/errno/mod.rs @@ -401,6 +401,19 @@ impl From for Errno { } } +impl TryFrom for litebox::net::errors::SocketAsyncError { + type Error = Errno; + + fn try_from(value: Errno) -> Result { + match value { + Errno::ECONNREFUSED => Ok(litebox::net::errors::SocketAsyncError::ConnectionRefused), + Errno::ECONNRESET => Ok(litebox::net::errors::SocketAsyncError::ConnectionReset), + Errno::ETIMEDOUT => Ok(litebox::net::errors::SocketAsyncError::TimedOut), + _ => Err(value), + } + } +} + impl From for Errno { fn from(value: litebox::net::errors::LocalAddrError) -> Self { match value { @@ -461,6 +474,21 @@ impl From for Errno { } } +impl From for Errno { + fn from(value: litebox::net::socket_channel::ChannelWriteError) -> Self { + match value { + litebox::net::socket_channel::ChannelWriteError::WriteShutdown + | litebox::net::socket_channel::ChannelWriteError::NotConnected + | litebox::net::socket_channel::ChannelWriteError::ConnectionClosed => Errno::EPIPE, + litebox::net::socket_channel::ChannelWriteError::Unaddressable => Errno::EINVAL, + litebox::net::socket_channel::ChannelWriteError::BufferFull => Errno::EAGAIN, + litebox::net::socket_channel::ChannelWriteError::DestinationAddressRequired => { + Errno::EDESTADDRREQ + } + } + } +} + impl From for Errno { fn from(value: litebox::net::errors::ReceiveError) -> Self { match value { diff --git a/litebox_common_linux/src/lib.rs b/litebox_common_linux/src/lib.rs index 7f9f85670..58a43aa6b 100644 --- a/litebox_common_linux/src/lib.rs +++ b/litebox_common_linux/src/lib.rs @@ -1799,6 +1799,9 @@ bitflags::bitflags! { const TRUNC = 0x20; /// `MSG_WAITALL`: wait for the full amount of data const WAITALL = 0x100; + /// `MSG_WAITFORONE`: `recvmmsg` only — turn on `MSG_DONTWAIT` after the + /// first message has been received. + const WAITFORONE = 0x10000; /// const _ = !0; } @@ -1878,7 +1881,9 @@ impl Copy for UserMsgHdr { + /// the per-message `msghdr` pub msg_hdr: UserMsgHdr, + /// bytes transmitted for this entry, written back by the kernel pub msg_len: u32, #[cfg(target_pointer_width = "64")] _pad: u32, @@ -2140,6 +2145,13 @@ pub enum SyscallRequest { msg: Platform::RawMutPointer>, flags: ReceiveFlags, }, + Recvmmsg { + sockfd: i32, + msgvec: Platform::RawMutPointer>, + vlen: u32, + flags: ReceiveFlags, + timeout: TimeParam, + }, Shutdown { sockfd: i32, how: i32, @@ -2605,6 +2617,13 @@ impl SyscallRequest { Sysno::sendmmsg => sys_req!(Sendmmsg { sockfd, msgvec:*, vlen, flags }), Sysno::recvfrom => sys_req!(Recvfrom { sockfd, buf:*, len, flags, addr:*, addrlen:*, }), Sysno::recvmsg => sys_req!(Recvmsg { sockfd, msg:*, flags }), + Sysno::recvmmsg => sys_req!(Recvmmsg { + sockfd, + msgvec:*, + vlen, + flags, + timeout: { =*> TimeParam::timespec_old } + }), Sysno::shutdown => sys_req!(Shutdown { sockfd, how }), Sysno::bind => sys_req!(Bind { sockfd, sockaddr:*, addrlen }), Sysno::listen => sys_req!(Listen { sockfd, backlog }), diff --git a/litebox_runner_linux_userland/tests/recvmmsg.c b/litebox_runner_linux_userland/tests/recvmmsg.c new file mode 100644 index 000000000..609d34db1 --- /dev/null +++ b/litebox_runner_linux_userland/tests/recvmmsg.c @@ -0,0 +1,470 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int g_fail = 0; + +static void check(int cond, const char *what) { + if (cond) { + printf(" PASS: %s\n", what); + } else { + printf(" FAIL: %s\n", what); + g_fail = 1; + } +} + +// Use the raw syscall so we exercise exactly what LiteBox intercepts; glibc's +// wrapper would otherwise be free to massage arguments before reaching the +// kernel. The libc prototype takes a non-const timespec, so do likewise. +static long sys_recvmmsg(int fd, struct mmsghdr *msgvec, unsigned int vlen, + int flags, struct timespec *timeout) { + return syscall(SYS_recvmmsg, fd, msgvec, vlen, flags, timeout); +} + +// Helper: build a vlen-sized array of mmsghdrs, each with one iov pointing at +// the corresponding row of `bufs`. msg_len is poisoned so we can prove the +// kernel wrote it. +static void build_recv_hdrs(struct mmsghdr *hdrs, struct iovec *iov, + char (*bufs)[64], unsigned int vlen) { + memset(hdrs, 0xAB, sizeof(*hdrs) * vlen); + for (unsigned int i = 0; i < vlen; i++) { + memset(bufs[i], 0, 64); + iov[i].iov_base = bufs[i]; + iov[i].iov_len = 63; // leave room for a NUL + memset(&hdrs[i].msg_hdr, 0, sizeof(hdrs[i].msg_hdr)); + hdrs[i].msg_hdr.msg_iov = &iov[i]; + hdrs[i].msg_hdr.msg_iovlen = 1; + hdrs[i].msg_len = 0xDEADBEEF; + } +} + +// --------------------------------------------------------------------------- +// Test 1: pre-buffer three datagrams and drain them in one recvmmsg call. +// Default flags=0 still works here because the data is already queued, so the +// "block until vlen messages" semantics never have to actually wait. +// --------------------------------------------------------------------------- +static void test_three_messages(void) { + puts("Test 1: recvmmsg drains multiple queued datagrams in one call"); + + int sv[2]; + if (socketpair(AF_UNIX, SOCK_DGRAM, 0, sv) < 0) { + perror("socketpair"); + exit(2); + } + + const char *payloads[3] = {"hello", "world!!", "third-msg"}; + for (int i = 0; i < 3; i++) { + ssize_t n = send(sv[0], payloads[i], strlen(payloads[i]), 0); + if (n != (ssize_t)strlen(payloads[i])) { + perror("send"); + exit(2); + } + } + + struct iovec iov[3]; + struct mmsghdr hdrs[3]; + char bufs[3][64]; + build_recv_hdrs(hdrs, iov, bufs, 3); + + errno = 0; + long n = sys_recvmmsg(sv[1], hdrs, 3, 0, NULL); + printf(" recvmmsg returned %ld (errno=%d %s)\n", n, errno, + n < 0 ? strerror(errno) : "-"); + check(n == 3, "recvmmsg returned 3 (number of messages received)"); + + for (int i = 0; i < 3; i++) { + unsigned int got = hdrs[i].msg_len; + unsigned int want = (unsigned int)strlen(payloads[i]); + if (got != want) { + printf(" msg_len[%d] = %u, want %u\n", i, got, want); + } + check(got == want, "msg_len matches payload length for each entry"); + check(strcmp(bufs[i], payloads[i]) == 0, + "iov buffer holds the expected datagram payload"); + } + + close(sv[0]); + close(sv[1]); +} + +// --------------------------------------------------------------------------- +// Test 2: vlen == 0 returns 0 immediately, no errno, no work done. +// --------------------------------------------------------------------------- +static void test_vlen_zero(void) { + puts("Test 2: recvmmsg with vlen == 0 returns 0"); + + int sv[2]; + if (socketpair(AF_UNIX, SOCK_DGRAM, 0, sv) < 0) { + perror("socketpair"); + exit(2); + } + + errno = 0; + long n = sys_recvmmsg(sv[1], NULL, 0, 0, NULL); + printf(" recvmmsg returned %ld (errno=%d %s)\n", n, errno, + n < 0 ? strerror(errno) : "-"); + check(n == 0, "vlen=0 returns 0"); + + close(sv[0]); + close(sv[1]); +} + +// --------------------------------------------------------------------------- +// Test 3: errno mapping for bad fd / bad msgvec pointer (when no message has +// been received yet). +// --------------------------------------------------------------------------- +static void test_errno_paths(void) { + puts("Test 3: recvmmsg errno on bad fd / bad msgvec pointer"); + + int sv[2]; + if (socketpair(AF_UNIX, SOCK_DGRAM, 0, sv) < 0) { + perror("socketpair"); + exit(2); + } + + errno = 0; + long n = sys_recvmmsg(-1, NULL, 1, 0, NULL); + printf(" fd=-1 vlen=1: ret=%ld errno=%d (%s)\n", n, errno, + n < 0 ? strerror(errno) : "-"); + check(n == -1 && errno == EBADF, "bad fd returns EBADF"); + + errno = 0; + n = sys_recvmmsg(9999, NULL, 1, 0, NULL); + printf(" fd=9999 vlen=1: ret=%ld errno=%d (%s)\n", n, errno, + n < 0 ? strerror(errno) : "-"); + check(n == -1 && errno == EBADF, "unused fd returns EBADF"); + + errno = 0; + n = sys_recvmmsg(sv[1], NULL, 1, MSG_DONTWAIT, NULL); + printf(" fd=ok msgvec=NULL vlen=1 DONTWAIT: ret=%ld errno=%d (%s)\n", n, + errno, n < 0 ? strerror(errno) : "-"); + check(n == -1 && errno == EFAULT, "NULL msgvec with vlen>0 returns EFAULT"); + + close(sv[0]); + close(sv[1]); +} + +// --------------------------------------------------------------------------- +// Test 4: MSG_DONTWAIT on an empty socket returns -1 / EAGAIN. +// --------------------------------------------------------------------------- +static void test_dontwait_empty(void) { + puts("Test 4: MSG_DONTWAIT on empty queue returns EAGAIN"); + + int sv[2]; + if (socketpair(AF_UNIX, SOCK_DGRAM, 0, sv) < 0) { + perror("socketpair"); + exit(2); + } + + struct iovec iov[2]; + struct mmsghdr hdrs[2]; + char bufs[2][64]; + build_recv_hdrs(hdrs, iov, bufs, 2); + + errno = 0; + long n = sys_recvmmsg(sv[1], hdrs, 2, MSG_DONTWAIT, NULL); + printf(" recvmmsg DONTWAIT empty: ret=%ld errno=%d (%s)\n", n, errno, + n < 0 ? strerror(errno) : "-"); + check(n == -1 && (errno == EAGAIN || errno == EWOULDBLOCK), + "empty queue with DONTWAIT returns EAGAIN/EWOULDBLOCK"); + + close(sv[0]); + close(sv[1]); +} + +// --------------------------------------------------------------------------- +// Test 5: MSG_DONTWAIT — partial drain. Pre-buffer two datagrams, ask for +// five; we should get two back and msg_len for those two should be set. +// --------------------------------------------------------------------------- +static void test_dontwait_partial(void) { + puts("Test 5: MSG_DONTWAIT returns however many are queued (partial)"); + + int sv[2]; + if (socketpair(AF_UNIX, SOCK_DGRAM, 0, sv) < 0) { + perror("socketpair"); + exit(2); + } + + const char *payloads[2] = {"alpha", "beta!!"}; + for (int i = 0; i < 2; i++) { + ssize_t n = send(sv[0], payloads[i], strlen(payloads[i]), 0); + if (n != (ssize_t)strlen(payloads[i])) { + perror("send"); + exit(2); + } + } + + struct iovec iov[5]; + struct mmsghdr hdrs[5]; + char bufs[5][64]; + build_recv_hdrs(hdrs, iov, bufs, 5); + + errno = 0; + long n = sys_recvmmsg(sv[1], hdrs, 5, MSG_DONTWAIT, NULL); + printf(" recvmmsg DONTWAIT vlen=5 queued=2: ret=%ld errno=%d (%s)\n", n, + errno, n < 0 ? strerror(errno) : "-"); + check(n == 2, "returns 2 (number queued)"); + for (int i = 0; i < 2; i++) { + unsigned int got = hdrs[i].msg_len; + unsigned int want = (unsigned int)strlen(payloads[i]); + if (got != want) { + printf(" msg_len[%d] = %u, want %u\n", i, got, want); + } + check(got == want, "msg_len matches payload length"); + check(strcmp(bufs[i], payloads[i]) == 0, "payload arrived in iov"); + } + + close(sv[0]); + close(sv[1]); +} + +// --------------------------------------------------------------------------- +// Test: tv_nsec/tv_sec validation. Linux validates the timespec before the +// fd/msgvec — `poll_select_set_timeout` runs first in `do_recvmmsg`. +// --------------------------------------------------------------------------- +static void test_bad_timespec(void) { + puts("Test: invalid timespec returns EINVAL (before EBADF/EFAULT)"); + + int sv[2]; + if (socketpair(AF_UNIX, SOCK_DGRAM, 0, sv) < 0) { + perror("socketpair"); + exit(2); + } + + struct iovec iov[1]; + struct mmsghdr hdrs[1]; + char bufs[1][64]; + build_recv_hdrs(hdrs, iov, bufs, 1); + + // tv_sec < 0 + struct timespec ts_neg_sec = {-1, 0}; + errno = 0; + long n = sys_recvmmsg(sv[1], hdrs, 1, MSG_DONTWAIT, &ts_neg_sec); + printf(" tv_sec=-1: ret=%ld errno=%d (%s)\n", n, errno, + n < 0 ? strerror(errno) : "-"); + check(n == -1 && errno == EINVAL, "negative tv_sec returns EINVAL"); + + // tv_nsec >= 1_000_000_000 + struct timespec ts_big_nsec = {0, 1000000000}; + errno = 0; + n = sys_recvmmsg(sv[1], hdrs, 1, MSG_DONTWAIT, &ts_big_nsec); + printf(" tv_nsec=1e9: ret=%ld errno=%d (%s)\n", n, errno, + n < 0 ? strerror(errno) : "-"); + check(n == -1 && errno == EINVAL, "tv_nsec >= 1e9 returns EINVAL"); + + // Validation precedes EBADF + errno = 0; + n = sys_recvmmsg(-1, hdrs, 1, MSG_DONTWAIT, &ts_neg_sec); + printf(" fd=-1 + tv_sec=-1: ret=%ld errno=%d (%s)\n", n, errno, + n < 0 ? strerror(errno) : "-"); + check(n == -1 && errno == EINVAL, + "timespec validation runs before fd check"); + + close(sv[0]); + close(sv[1]); +} + +// --------------------------------------------------------------------------- +// Test: non-NULL timeout on an empty queue with MSG_DONTWAIT still returns +// EAGAIN (DONTWAIT short-circuits the inner recvmsg before the deadline ever +// matters). +// --------------------------------------------------------------------------- +static void test_timeout_dontwait_empty(void) { + puts("Test: timeout + DONTWAIT on empty queue returns EAGAIN"); + + int sv[2]; + if (socketpair(AF_UNIX, SOCK_DGRAM, 0, sv) < 0) { + perror("socketpair"); + exit(2); + } + + struct iovec iov[2]; + struct mmsghdr hdrs[2]; + char bufs[2][64]; + build_recv_hdrs(hdrs, iov, bufs, 2); + + struct timespec ts = {0, 10000000}; // 10ms + errno = 0; + long n = sys_recvmmsg(sv[1], hdrs, 2, MSG_DONTWAIT, &ts); + printf(" recvmmsg DONTWAIT timeout=10ms empty: ret=%ld errno=%d (%s)\n", n, + errno, n < 0 ? strerror(errno) : "-"); + check(n == -1 && (errno == EAGAIN || errno == EWOULDBLOCK), + "DONTWAIT on empty with timeout returns EAGAIN"); + + close(sv[0]); + close(sv[1]); +} + +// --------------------------------------------------------------------------- +// Test: zero timespec ({0,0}) on a queued socket reads exactly ONE message +// before the deadline check fires. `poll_select_set_timeout` parks the +// deadline at {0,0}, which compares as already past after the first recvmsg +// returns, so the loop exits with datagrams=1. +// --------------------------------------------------------------------------- +static void test_timeout_zero_caps_at_one(void) { + puts("Test: timeout={0,0} reads exactly 1 message even with vlen > 1"); + + int sv[2]; + if (socketpair(AF_UNIX, SOCK_DGRAM, 0, sv) < 0) { + perror("socketpair"); + exit(2); + } + + const char *payloads[3] = {"first", "second", "third!"}; + for (int i = 0; i < 3; i++) { + ssize_t sent = send(sv[0], payloads[i], strlen(payloads[i]), 0); + if (sent != (ssize_t)strlen(payloads[i])) { + perror("send"); + exit(2); + } + } + + struct iovec iov[3]; + struct mmsghdr hdrs[3]; + char bufs[3][64]; + build_recv_hdrs(hdrs, iov, bufs, 3); + + struct timespec ts = {0, 0}; + errno = 0; + long n = sys_recvmmsg(sv[1], hdrs, 3, 0, &ts); + printf(" recvmmsg ts={0,0} queued=3 vlen=3: ret=%ld errno=%d (%s)\n", n, + errno, n < 0 ? strerror(errno) : "-"); + check(n == 1, "timeout={0,0} returns 1"); + check(hdrs[0].msg_len == strlen(payloads[0]), + "msg_len[0] matches first payload length"); + check(strcmp(bufs[0], payloads[0]) == 0, + "iov[0] holds the first payload"); + // msg_len[1] was poisoned to 0xDEADBEEF and shouldn't have been written. + check(hdrs[1].msg_len == 0xDEADBEEF, "msg_len[1] left untouched"); + + close(sv[0]); + close(sv[1]); +} + +// --------------------------------------------------------------------------- +// Test: a generous timeout doesn't truncate a multi-message drain. With a 5s +// deadline and three queued datagrams, all three are read before the deadline +// is even close. +// --------------------------------------------------------------------------- +static void test_timeout_generous_drains_all(void) { + puts("Test: generous timeout drains all queued messages"); + + int sv[2]; + if (socketpair(AF_UNIX, SOCK_DGRAM, 0, sv) < 0) { + perror("socketpair"); + exit(2); + } + + const char *payloads[3] = {"aaa", "bbbb", "ccccc"}; + for (int i = 0; i < 3; i++) { + ssize_t sent = send(sv[0], payloads[i], strlen(payloads[i]), 0); + if (sent != (ssize_t)strlen(payloads[i])) { + perror("send"); + exit(2); + } + } + + struct iovec iov[5]; + struct mmsghdr hdrs[5]; + char bufs[5][64]; + build_recv_hdrs(hdrs, iov, bufs, 5); + + struct timespec ts = {5, 0}; + errno = 0; + long n = sys_recvmmsg(sv[1], hdrs, 5, MSG_DONTWAIT, &ts); + printf(" recvmmsg DONTWAIT ts=5s queued=3 vlen=5: ret=%ld errno=%d (%s)\n", + n, errno, n < 0 ? strerror(errno) : "-"); + check(n == 3, "drained all 3 queued messages"); + for (int i = 0; i < 3; i++) { + check(hdrs[i].msg_len == strlen(payloads[i]), + "msg_len matches payload length"); + check(strcmp(bufs[i], payloads[i]) == 0, "payload arrived in iov"); + } + // Linux writes the remaining time back into the user's timespec on + // success (see `put_timespec64` in `__sys_recvmmsg`). Native probe with + // queued(3)/vlen=5 reports e.g. {4, 999997470}: drain takes microseconds + // so the residual is just-under 5s, but the kernel did update it. + printf(" remaining timespec: {%ld, %ld}\n", (long)ts.tv_sec, (long)ts.tv_nsec); + check(ts.tv_sec < 5, "remaining tv_sec decremented below original 5"); + check(ts.tv_sec >= 0, "remaining tv_sec did not go negative"); + + close(sv[0]); + close(sv[1]); +} + +// --------------------------------------------------------------------------- +// Test 6: MSG_WAITFORONE — pre-buffer two datagrams, ask for five with +// MSG_WAITFORONE. We should get exactly two (the first read blocks, but data +// is already there; subsequent reads are non-blocking). +// --------------------------------------------------------------------------- +static void test_waitforone(void) { + puts("Test 6: MSG_WAITFORONE drains pre-buffered datagrams"); + + int sv[2]; + if (socketpair(AF_UNIX, SOCK_DGRAM, 0, sv) < 0) { + perror("socketpair"); + exit(2); + } + + const char *payloads[2] = {"one", "two-two"}; + for (int i = 0; i < 2; i++) { + ssize_t n = send(sv[0], payloads[i], strlen(payloads[i]), 0); + if (n != (ssize_t)strlen(payloads[i])) { + perror("send"); + exit(2); + } + } + + struct iovec iov[5]; + struct mmsghdr hdrs[5]; + char bufs[5][64]; + build_recv_hdrs(hdrs, iov, bufs, 5); + + errno = 0; + long n = sys_recvmmsg(sv[1], hdrs, 5, MSG_WAITFORONE, NULL); + printf(" recvmmsg WAITFORONE vlen=5 queued=2: ret=%ld errno=%d (%s)\n", n, + errno, n < 0 ? strerror(errno) : "-"); + check(n == 2, "returns 2 (drains all queued after first)"); + for (int i = 0; i < 2; i++) { + unsigned int got = hdrs[i].msg_len; + unsigned int want = (unsigned int)strlen(payloads[i]); + check(got == want, "msg_len matches payload length"); + check(strcmp(bufs[i], payloads[i]) == 0, "payload arrived in iov"); + } + + close(sv[0]); + close(sv[1]); +} + +int main(void) { + puts("recvmmsg parity test"); + test_three_messages(); + test_vlen_zero(); + test_errno_paths(); + test_dontwait_empty(); + test_dontwait_partial(); + test_waitforone(); + test_bad_timespec(); + test_timeout_dontwait_empty(); + test_timeout_zero_caps_at_one(); + test_timeout_generous_drains_all(); + + if (g_fail) { + puts("\nRESULT: BUG(S) REPRODUCED"); + return 1; + } + puts("\nAll recvmmsg tests passed."); + return 0; +} diff --git a/litebox_shim_linux/src/lib.rs b/litebox_shim_linux/src/lib.rs index 1873ea359..67a8fadf6 100644 --- a/litebox_shim_linux/src/lib.rs +++ b/litebox_shim_linux/src/lib.rs @@ -703,6 +703,13 @@ impl Task { addrlen, } => self.sys_recvfrom(sockfd, buf, len, flags, addr, addrlen), SyscallRequest::Recvmsg { sockfd, msg, flags } => self.sys_recvmsg(sockfd, msg, flags), + SyscallRequest::Recvmmsg { + sockfd, + msgvec, + vlen, + flags, + timeout, + } => self.sys_recvmmsg(sockfd, msgvec, vlen, flags, timeout), SyscallRequest::Shutdown { sockfd, how } => syscall!(sys_shutdown(sockfd, how)), SyscallRequest::Bind { sockfd, diff --git a/litebox_shim_linux/src/syscalls/net.rs b/litebox_shim_linux/src/syscalls/net.rs index a155261c5..997ccb632 100644 --- a/litebox_shim_linux/src/syscalls/net.rs +++ b/litebox_shim_linux/src/syscalls/net.rs @@ -5,7 +5,7 @@ use core::{ ffi::CStr, - mem::offset_of, + mem::{offset_of, size_of}, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, }; @@ -21,15 +21,15 @@ use litebox::{ net::{ CloseBehavior, TcpOptionData, errors::AcceptError, - socket_channel::{NetworkProxy, SocketState}, + socket_channel::{ChannelReadError, ChannelWriteError, NetworkProxy, SocketState}, }, - platform::{RawConstPointer as _, RawMutPointer as _}, + platform::{Instant as _, RawConstPointer as _, RawMutPointer as _, TimeProvider as _}, utils::TruncateExt as _, }; use litebox_common_linux::{ AddressFamily, FileDescriptorFlags, IPProtocol, ReceiveFlags, SendFlags, ShutdownHow, - SockFlags, SockType, SocketOption, SocketOptionName, TcpOption, UnixProtocol, errno::Errno, - signal::Signal, + SockFlags, SockType, SocketOption, SocketOptionName, TcpOption, UnixProtocol, UserMmsgHdr, + UserMsgHdr, errno::Errno, signal::Signal, }; use zerocopy::{FromBytes, Immutable, IntoBytes}; @@ -41,7 +41,7 @@ use crate::{ }; /// Linux's hard cap on the number of iovecs per `*msg`-style call, and on the -/// number of entries per `*mmsg`-style call. See `UIO_MAXIOV` in ``. +/// number of entries per `sendmmsg`. See `UIO_MAXIOV` in ``. const UIO_MAXIOV: usize = 1024; macro_rules! convert_flags { @@ -772,7 +772,7 @@ impl GlobalState { Ok(0) if buf.is_empty() => Ok(0), Ok(0) => Err(TryOpError::TryAgain), Ok(n) => Ok(n), - Err(litebox::net::errors::SendError::BufferFull) if is_empty_stream => Ok(0), + Err(ChannelWriteError::BufferFull) if is_empty_stream => Ok(0), Err(e) => Err(TryOpError::Other(Errno::from(e))), }, ) @@ -830,7 +830,12 @@ impl GlobalState { || match proxy.try_read(buf, new_flags, source_addr.as_deref_mut()) { Ok(0) => Err(TryOpError::TryAgain), Ok(n) => Ok(n), - Err(e) => Err(TryOpError::Other(Errno::from(e))), + Err(ChannelReadError::ReadShutdown) => Ok(0), + Err(ChannelReadError::ConnectionClosed) => match proxy.get_async_error(true) { + Some(err) => Err(TryOpError::Other(err.into())), + None => Ok(0), + }, + Err(ChannelReadError::NotConnected) => Err(TryOpError::Other(Errno::ENOTCONN)), }, ) .map_err(Errno::from) @@ -1642,6 +1647,14 @@ impl Task { return Err(Errno::EINVAL); } + self.do_recvmsg(sockfd, msg_ptr, flags) + } + fn do_recvmsg( + &self, + sockfd: u32, + msg_ptr: MutPtr>, + flags: ReceiveFlags, + ) -> Result { let msg = msg_ptr.read_at_offset(0).ok_or(Errno::EFAULT)?; // Copy fields out of the packed struct to avoid unaligned references. @@ -1749,6 +1762,117 @@ impl Task { Ok(total_received) } + /// Handle syscall `recvmmsg` + pub(crate) fn sys_recvmmsg( + &self, + fd: i32, + msgvec: MutPtr>, + vlen: u32, + flags: ReceiveFlags, + timeout: litebox_common_linux::TimeParam, + ) -> Result { + let supported_flags = + ReceiveFlags::DONTWAIT | ReceiveFlags::TRUNC | ReceiveFlags::WAITFORONE; + if flags.intersects(supported_flags.complement()) { + log_unsupported!("Unsupported recvmmsg flags: {:?}", flags); + return Err(Errno::EINVAL); + } + + // Linux's `do_recvmmsg` validates the timespec before looking up the fd, + // so a bad timeout takes precedence over EBADF. + let timeout_duration = timeout.read()?; + + let Ok(sockfd) = u32::try_from(fd) else { + return Err(Errno::EBADF); + }; + + let vlen = vlen as usize; + + // Linux looks up the fd before touching vlen/msgvec, so a bogus fd + // takes priority over a bogus msgvec pointer or vlen == 0. + let inet_proxy = self.files.borrow().with_socket( + &self.global, + sockfd, + |fd| self.global.get_proxy(fd).map(Some), + |_| Ok(None), + )?; + + if vlen == 0 { + return Ok(0); + } + + // A `None` deadline means either no user-supplied timeout or a saturating overflow + // — both are treated as "no deadline". + let deadline = timeout_duration.and_then(|d| self.global.platform.now().checked_add(d)); + + let stride = size_of::>(); + let msg_len_off = offset_of!(UserMmsgHdr, msg_len); + let msgvec_base = msgvec.as_usize(); + let msgvec_len = vlen.checked_mul(stride).ok_or(Errno::EFAULT)?; + if msgvec_base.checked_add(msgvec_len).is_none() { + return Err(Errno::EFAULT); + } + + // WAITFORONE is mmsg-only; the inner recvmsg doesn't recognize it. + let waitforone = flags.contains(ReceiveFlags::WAITFORONE); + let mut iter_flags = flags.difference(ReceiveFlags::WAITFORONE); + let mut received: usize = 0; + let mut last_err: Option = None; + let mut async_error_to_restore = None; + for i in 0..vlen { + let base = msgvec_base + i * stride; + let inner_ptr = MutPtr::>::from_usize(base); + let n = match self.do_recvmsg(sockfd, inner_ptr, iter_flags) { + Ok(n) => n, + Err(e) => { + if received > 0 { + async_error_to_restore = e.try_into().ok(); + } + last_err = Some(e); + break; + } + }; + let msg_len_ptr = MutPtr::::from_usize(base + msg_len_off); + if msg_len_ptr.write_at_offset(0, n.trunc()).is_none() { + last_err = Some(Errno::EFAULT); + break; + } + received += 1; + if waitforone { + iter_flags.insert(ReceiveFlags::DONTWAIT); + } + + // Per the man page, the timeout is checked only after the receipt of each datagram. + if let Some(deadline) = deadline + && self.global.platform.now() >= deadline + { + break; + } + } + + if received == 0 { + // The only way to exit the loop with received=0 is via an inner + // recvmsg error; EAGAIN is the conservative fallback for the + // structurally unreachable case. + return Err(last_err.unwrap_or(Errno::EAGAIN)); + } + + // Stash the suppressed async socket error back onto the socket. + if let (Some(async_error), Some(proxy)) = (async_error_to_restore, inet_proxy) { + proxy.set_async_error(async_error); + } + + // Match Linux's `__sys_recvmmsg`: the remaining timespec is only + // written back when at least one datagram was received. A write + // EFAULT here overrides the success return, mirroring Linux. + let remaining = deadline + .and_then(|d| d.checked_duration_since(&self.global.platform.now())) + .unwrap_or(core::time::Duration::ZERO); + timeout.write(remaining)?; + + Ok(received) + } + pub(crate) fn sys_setsockopt( &self, sockfd: i32, diff --git a/litebox_shim_linux/src/transport.rs b/litebox_shim_linux/src/transport.rs index b4628a60a..ce450f53e 100644 --- a/litebox_shim_linux/src/transport.rs +++ b/litebox_shim_linux/src/transport.rs @@ -7,7 +7,7 @@ use alloc::boxed::Box; use alloc::sync::Arc; use litebox::fs::nine_p::transport; -use litebox::net::socket_channel::NetworkProxy; +use litebox::net::socket_channel::{ChannelWriteError, NetworkProxy}; use litebox::net::{ReceiveFlags, SendFlags}; use litebox_common_linux::{SockFlags, SockType, errno::Errno}; @@ -121,7 +121,7 @@ impl transport::Write for ShimTransport { loop { match self.proxy.try_write(buf, SendFlags::empty(), None) { Ok(n) => return Ok(n), - Err(litebox::net::errors::SendError::BufferFull) => { + Err(ChannelWriteError::BufferFull) => { // TX ring full — spin until space opens up. core::hint::spin_loop(); }