diff --git a/Cargo.toml b/Cargo.toml index 59b323f..8047697 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ members = [ resolver = "2" [workspace.package] -version = "0.1.3" +version = "0.1.4" edition = "2024" rust-version = "1.85" license = "MIT" diff --git a/msg-socket/src/req/conn_manager.rs b/msg-socket/src/req/conn_manager.rs new file mode 100644 index 0000000..9a87dfc --- /dev/null +++ b/msg-socket/src/req/conn_manager.rs @@ -0,0 +1,213 @@ +use std::{ + io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; + +use bytes::Bytes; +use futures::{Future, FutureExt, SinkExt, StreamExt}; +use msg_common::span::{EnterSpan as _, WithSpan}; +use tokio_util::codec::Framed; +use tracing::Instrument; + +use crate::{ConnectionState, ExponentialBackoff}; + +use msg_transport::{Address, MeteredIo, Transport}; +use msg_wire::{auth, reqrep}; + +/// A connection task that connects to a server and returns the underlying IO object. +type ConnTask = Pin> + Send>>; + +/// A connection from the transport to a server. +/// +/// # Usage of Framed +/// [`Framed`] is used for encoding and decoding messages ("frames"). +/// Usually, [`Framed`] has its own internal buffering mechanism, that's respected +/// when calling `poll_ready` and configured by [`Framed::set_backpressure_boundary`]. +/// +/// However, we don't use `poll_ready` here, and instead we flush every time we write a message to +/// the framed buffer. +pub(crate) type Conn = Framed, reqrep::Codec>; + +/// A connection controller that manages the connection to a server with an exponential backoff. +pub(crate) type ConnCtl = ConnectionState, ExponentialBackoff, A>; + +/// Manages the connection lifecycle: connecting, reconnecting, and maintaining the connection. +pub(crate) struct ConnManager, A: Address> { + /// The connection task which handles the connection to the server. + conn_task: Option>>, + /// The transport controller, wrapped in a [`ConnectionState`] for backoff. + /// The [`Framed`] object can send and receive messages from the socket. + conn_ctl: ConnCtl, + /// The transport for this socket. + transport: T, + /// The address of the server. + addr: A, + /// Transport stats for metering IO. + transport_stats: Arc>, + /// Authentication token for the connection. + auth_token: Option, + + /// A span to use for connection-related logging. + span: tracing::Span, +} + +/// Perform the authentication handshake with the server. +#[tracing::instrument(skip_all, "auth", fields(token = ?token))] +async fn authentication_handshake(mut io: T::Io, token: Bytes) -> Result +where + T: Transport, + A: Address, +{ + let mut conn = Framed::new(&mut io, auth::Codec::new_client()); + + conn.send(auth::Message::Auth(token)).await?; + tracing::debug!("sent auth, waiting ack from server"); + + // Wait for the response + let Some(res) = conn.next().await else { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "connection closed").into()); + }; + + match res { + Ok(auth::Message::Ack) => { + tracing::debug!("received ack"); + Ok(io) + } + Ok(msg) => { + tracing::error!(?msg, "unexpected ack result"); + Err(io::Error::new(io::ErrorKind::PermissionDenied, "rejected").into()) + } + Err(e) => Err(io::Error::new(io::ErrorKind::PermissionDenied, e).into()), + } +} + +impl ConnManager +where + T: Transport, + A: Address, +{ + pub(crate) fn new( + transport: T, + addr: A, + conn_ctl: ConnCtl, + transport_stats: Arc>, + auth_token: Option, + span: tracing::Span, + ) -> Self { + Self { conn_task: None, conn_ctl, transport, addr, transport_stats, auth_token, span } + } + + /// Start the connection task to the server, handling authentication if necessary. + /// The result will be polled by the driver and re-tried according to the backoff policy. + fn try_connect(&mut self) { + let connect = self.transport.connect(self.addr.clone()); + let token = self.auth_token.clone(); + + let task = async move { + let io = connect.await?; + + let Some(token) = token else { + return Ok(io); + }; + + authentication_handshake::(io, token).await + } + .in_current_span(); + + // FIX: coercion to BoxFuture for [`SpanExt::with_current_span`] + self.conn_task = Some(WithSpan::current(Box::pin(task))); + } + + /// Reset the connection state to inactive, so that it will be re-tried. + /// This is done when the connection is closed or an error occurs. + #[inline] + pub(crate) fn reset_connection(&mut self) { + self.conn_ctl = ConnectionState::Inactive { + addr: self.addr.clone(), + backoff: ExponentialBackoff::new(Duration::from_millis(20), 16), + }; + } + + /// Returns a mutable reference to the connection channel if it is active. + #[inline] + pub(crate) fn active_connection(&mut self) -> Option<&mut Conn> { + if let ConnectionState::Active { ref mut channel } = self.conn_ctl { + Some(channel) + } else { + None + } + } + + /// Poll connection management logic: connection task, backoff, and retry logic. + /// Loops until the connection is active, then returns a mutable reference to the channel. + /// + /// Note: this is not a `Future` impl because we want to return a reference; doing it in + /// a `Future` would require lifetime headaches or unsafe code. + /// + /// Returns: + /// * `Poll::Ready(Some(&mut channel))` if the connection is active + /// * `Poll::Ready(None)` if we should terminate (max retries exceeded) + /// * `Poll::Pending` if we need to wait for backoff + #[allow(clippy::type_complexity)] + pub(crate) fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + // Poll the active connection task, if any + if let Some(ref mut conn_task) = self.conn_task { + if let Poll::Ready(result) = conn_task.poll_unpin(cx).enter() { + // As soon as the connection task finishes, set it to `None`. + // - If it was successful, set the connection to active + // - If it failed, it will be re-tried until the backoff limit is reached. + self.conn_task = None; + + match result.inner { + Ok(io) => { + tracing::info!("connected"); + + let metered = MeteredIo::new(io, self.transport_stats.clone()); + let framed = Framed::new(metered, reqrep::Codec::new()); + self.conn_ctl = ConnectionState::Active { channel: framed }; + } + Err(e) => { + tracing::error!(?e, "failed to connect"); + } + } + } + } + + // If the connection is inactive, try to connect to the server or poll the backoff + // timer if we're already trying to connect. + if let ConnectionState::Inactive { backoff, .. } = &mut self.conn_ctl { + let Poll::Ready(item) = backoff.poll_next_unpin(cx) else { + return Poll::Pending; + }; + + let _span = tracing::info_span!(parent: &self.span, "connect").entered(); + + if let Some(duration) = item { + if self.conn_task.is_none() { + tracing::debug!(backoff = ?duration, "trying connection"); + self.try_connect(); + } else { + tracing::debug!( + backoff = ?duration, + "not retrying as there is already a connection task" + ); + } + } else { + tracing::error!("exceeded maximum number of retries, terminating connection"); + return Poll::Ready(None); + } + } + + if let ConnectionState::Active { ref mut channel } = self.conn_ctl { + return Poll::Ready(Some(channel)); + } + } + } +} diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index f06a673..e213d4b 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -1,48 +1,32 @@ use std::{ collections::VecDeque, - io, pin::Pin, sync::Arc, task::{Context, Poll, ready}, - time::{Duration, Instant}, + time::Instant, }; use bytes::Bytes; -use futures::{Future, FutureExt, SinkExt, StreamExt}; +use futures::{Future, SinkExt, StreamExt}; use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan}; use rustc_hash::FxHashMap; use tokio::{ sync::{mpsc, oneshot}, time::Interval, }; -use tokio_util::codec::Framed; -use tracing::Instrument; use super::{ReqError, ReqOptions}; -use crate::{ConnectionState, ExponentialBackoff, SendCommand, req::SocketState}; +use crate::{ + SendCommand, + req::{SocketState, conn_manager::ConnManager}, +}; -use msg_transport::{Address, MeteredIo, Transport}; +use msg_transport::{Address, Transport}; use msg_wire::{ - auth::{self}, compression::{Compressor, try_decompress_payload}, reqrep, }; -/// A connection task that connects to a server and returns the underlying IO object. -type ConnectionTask = Pin> + Send>>; - -/// A connection controller that manages the connection to a server with an exponential backoff. -/// -/// # Usage of Framed -/// [`Framed`] is used for encoding and decoding messages ("frames"). -/// Usually, [`Framed`] has its own internal buffering mechanism, that's respected -/// when calling `poll_ready` and configured by [`Framed::set_backpressure_boundary`]. -/// -/// However, we don't use `poll_ready` here, and instead we flush every time we write a message to -/// the framed buffer. -pub(crate) type ConnectionCtl = - ConnectionState, reqrep::Codec>, ExponentialBackoff, A>; - /// The request socket driver. Endless future that drives /// the socket forward. pub(crate) struct ReqDriver, A: Address> { @@ -54,15 +38,8 @@ pub(crate) struct ReqDriver, A: Address> { pub(crate) id_counter: u32, /// Commands from the socket. pub(crate) from_socket: mpsc::Receiver, - /// The transport for this socket. - pub(crate) transport: T, - /// The address of the server. - pub(crate) addr: A, - /// The connection task which handles the connection to the server. - pub(crate) conn_task: Option>>, - /// The transport controller, wrapped in a [`ConnectionState`] for backoff. - /// The [`Framed`] object can send and receive messages from the socket. - pub(crate) conn_state: ConnectionCtl, + /// Connection manager that handles connection lifecycle. + pub(crate) conn_manager: ConnManager, /// The timer for the write buffer linger. pub(crate) linger_timer: Option, /// The outgoing message queue. @@ -89,62 +66,11 @@ pub(crate) struct PendingRequest { sender: oneshot::Sender>, } -/// Perform the authentication handshake with the server. -#[tracing::instrument(skip_all, "auth", fields(token = ?token))] -async fn authentication_handshake(mut io: T::Io, token: Bytes) -> Result -where - T: Transport, - A: Address, -{ - let mut conn = Framed::new(&mut io, auth::Codec::new_client()); - - conn.send(auth::Message::Auth(token)).await?; - tracing::debug!("sent auth, waiting ack from server"); - - // Wait for the response - let Some(res) = conn.next().await else { - return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "connection closed").into()); - }; - - match res { - Ok(auth::Message::Ack) => { - tracing::debug!("received ack"); - Ok(io) - } - Ok(msg) => { - tracing::error!(?msg, "unexpected ack result"); - Err(io::Error::new(io::ErrorKind::PermissionDenied, "rejected").into()) - } - Err(e) => Err(io::Error::new(io::ErrorKind::PermissionDenied, e).into()), - } -} - impl ReqDriver where T: Transport, A: Address, { - /// Start the connection task to the server, handling authentication if necessary. - /// The result will be polled by the driver and re-tried according to the backoff policy. - fn try_connect(&mut self, addr: A) { - let connect = self.transport.connect(addr.clone()); - let token = self.options.auth_token.clone(); - - let task = async move { - let io = connect.await?; - - let Some(token) = token else { - return Ok(io); - }; - - authentication_handshake::(io, token).await - } - .in_current_span(); - - // FIX: coercion to BoxFuture for [`SpanExt::with_current_span`] - self.conn_task = Some(WithSpan::current(Box::pin(task))); - } - /// Handle an incoming message from the connection. fn on_message(&mut self, msg: reqrep::Message) { let Some(pending) = self.pending_requests.remove(&msg.id()).enter() else { @@ -237,16 +163,6 @@ where } } } - - /// Reset the connection state to inactive, so that it will be re-tried. - /// This is done when the connection is closed or an error occurs. - #[inline] - fn reset_connection(&mut self) { - self.conn_state = ConnectionState::Inactive { - addr: self.addr.clone(), - backoff: ExponentialBackoff::new(Duration::from_millis(20), 16), - }; - } } impl Future for ReqDriver @@ -261,58 +177,11 @@ where let span = this.span.clone(); loop { - // TODO: Group connection management together in a function or at a different level of - // abstraction. - - // Poll the active connection task, if any - if let Some(ref mut conn_task) = this.conn_task { - if let Poll::Ready(result) = conn_task.poll_unpin(cx).enter() { - // As soon as the connection task finishes, set it to `None`. - // - If it was successful, set the connection to active - // - If it failed, it will be re-tried until the backoff limit is reached. - this.conn_task = None; - - match result.inner { - Ok(io) => { - tracing::info!("connected"); - - let metered = - MeteredIo::new(io, Arc::clone(&this.socket_state.transport_stats)); - - let framed = Framed::new(metered, reqrep::Codec::new()); - this.conn_state = ConnectionState::Active { channel: framed }; - } - Err(e) => { - tracing::error!(?e, "failed to connect"); - } - } - } - } - - // If the connection is inactive, try to connect to the server or poll the backoff - // timer if we're already trying to connect. - if let ConnectionState::Inactive { ref mut backoff, ref addr } = this.conn_state { - let Poll::Ready(item) = backoff.poll_next_unpin(cx) else { return Poll::Pending }; - - let _span = tracing::info_span!(parent: &this.span, "connect").entered(); - - if let Some(duration) = item { - if this.conn_task.is_none() { - tracing::debug!(backoff = ?duration, "trying connection"); - this.try_connect(addr.clone()); - } else { - tracing::debug!(backoff = ?duration, "not retrying as there is already a connection task"); - } - } else { - tracing::error!("exceeded maximum number of retries, terminating connection"); - - return Poll::Ready(()); - } - } - - // If there is no active connection, continue polling the backoff - let ConnectionState::Active { ref mut channel } = this.conn_state else { - continue; + // Handle connection management: connection task, backoff, and retry logic + let channel = match this.conn_manager.poll(cx) { + Poll::Ready(Some(channel)) => channel, + Poll::Ready(None) => return Poll::Ready(()), + Poll::Pending => return Poll::Pending, }; // Check for incoming messages from the socket @@ -330,7 +199,7 @@ where ); // set the connection to inactive, so that it will be re-tried - this.reset_connection(); + this.conn_manager.reset_connection(); continue; } @@ -339,7 +208,7 @@ where tracing::warn!("connection closed, resetting connection state"); // set the connection to inactive, so that it will be re-tried - this.reset_connection(); + this.conn_manager.reset_connection(); continue; } @@ -363,7 +232,7 @@ where tracing::error!(err = ?e, "Failed to send message to socket"); // set the connection to inactive, so that it will be re-tried - this.reset_connection(); + this.conn_manager.reset_connection(); } } @@ -374,7 +243,7 @@ where if channel.write_buffer().len() >= this.options.write_buffer_size { if let Poll::Ready(Err(e)) = channel.poll_flush_unpin(cx) { tracing::error!(err = ?e, "Failed to flush connection"); - this.reset_connection(); + this.conn_manager.reset_connection(); continue; } @@ -388,7 +257,7 @@ where if !channel.write_buffer().is_empty() && linger_timer.poll_tick(cx).is_ready() { if let Poll::Ready(Err(e)) = channel.poll_flush_unpin(cx) { tracing::error!(err = ?e, "Failed to flush connection"); - this.reset_connection(); + this.conn_manager.reset_connection(); } } } @@ -410,9 +279,9 @@ where "socket dropped, shutting down backend and flushing connection" ); - if let ConnectionState::Active { ref mut channel } = this.conn_state { + if let Some(channel) = this.conn_manager.active_connection() { let _ = ready!(channel.poll_close_unpin(cx)); - }; + } return Poll::Ready(()); } diff --git a/msg-socket/src/req/mod.rs b/msg-socket/src/req/mod.rs index 85247fa..d17f758 100644 --- a/msg-socket/src/req/mod.rs +++ b/msg-socket/src/req/mod.rs @@ -13,6 +13,7 @@ use msg_wire::{ reqrep, }; +mod conn_manager; mod driver; mod socket; mod stats; diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index ecad547..6b6b082 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -23,7 +23,8 @@ use crate::{ ConnectionState, DRIVER_ID, ExponentialBackoff, ReqMessage, SendCommand, req::{ SocketState, - driver::{ConnectionCtl, ReqDriver}, + conn_manager::{ConnCtl, ConnManager}, + driver::ReqDriver, stats::ReqStats, }, stats::SocketStats, @@ -167,12 +168,7 @@ where } /// Internal method to initialize and spawn the driver. - fn spawn_driver( - &mut self, - endpoint: A, - transport: T, - conn_state: ConnectionCtl, - ) { + fn spawn_driver(&mut self, endpoint: A, transport: T, conn_ctl: ConnCtl) { // Initialize communication channels let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE); @@ -191,19 +187,26 @@ where timer }); + // Create connection manager + let conn_manager = ConnManager::new( + transport, + endpoint, + conn_ctl, + Arc::clone(&self.state.transport_stats), + self.options.auth_token.clone(), + span.clone(), + ); + // Create the socket backend let driver: ReqDriver = ReqDriver { - addr: endpoint, options: Arc::clone(&self.options), socket_state: self.state.clone(), id_counter: 0, from_socket, - transport, - conn_state, + conn_manager, linger_timer, pending_requests, timeout_check_interval, - conn_task: None, egress_queue: Default::default(), compressor: self.compressor.clone(), id,