diff --git a/integration/rust/tests/integration/maintenance_mode.rs b/integration/rust/tests/integration/maintenance_mode.rs index 4f089b1d5..8659cda00 100644 --- a/integration/rust/tests/integration/maintenance_mode.rs +++ b/integration/rust/tests/integration/maintenance_mode.rs @@ -120,7 +120,7 @@ async fn test_maintenance_mode_parsing() { let result = admin.simple_query("MAINTENANCE INVALID").await; assert!(result.is_err()); - let result = admin.simple_query("MAINTENANCE ON EXTRA").await; + let result = admin.simple_query("MAINTENANCE ON DB_NAME EXTRA").await; assert!(result.is_err()); } @@ -156,6 +156,67 @@ async fn test_maintenance_mode_concurrent_operations() { tokio::try_join!(admin_task, client_task).unwrap(); } +#[tokio::test] +#[serial] +async fn test_maintenance_mode_single_database() { + let admin = admin_tokio().await; + let failover = connection_failover().await; + let mut pools = connections_sqlx().await; + let pgdog = pools.remove(0); // "pgdog" database + + // Clean slate. + admin.simple_query("MAINTENANCE OFF").await.unwrap(); + admin + .simple_query("MAINTENANCE OFF failover") + .await + .unwrap(); + + // Both databases work normally. + pgdog.execute("SELECT 1").await.unwrap(); + failover.execute("SELECT 1").await.unwrap(); + + // Put only the `failover` database into maintenance mode. + admin.simple_query("MAINTENANCE ON failover").await.unwrap(); + + // Queries to `pgdog` must keep working, unaffected by `failover`. + tokio::time::timeout(Duration::from_secs(1), pgdog.execute("SELECT 2")) + .await + .expect("pgdog query should not block while only failover is in maintenance") + .unwrap(); + + // Meanwhile, a query to `failover` should block until maintenance is lifted. + let failover_blocked = failover.clone(); + let blocked = tokio::spawn(async move { + failover_blocked.execute("SELECT 2").await.unwrap(); + }); + + sleep(Duration::from_millis(100)).await; + assert!( + !blocked.is_finished(), + "failover query should be blocked during maintenance" + ); + + // `pgdog` still works while the `failover` query is blocked. + tokio::time::timeout(Duration::from_secs(1), pgdog.execute("SELECT 3")) + .await + .expect("pgdog query should not block") + .unwrap(); + + // Lift maintenance on `failover`; the blocked query should now complete. + admin + .simple_query("MAINTENANCE OFF failover") + .await + .unwrap(); + + tokio::time::timeout(Duration::from_secs(5), blocked) + .await + .expect("failover query should complete after maintenance is lifted") + .unwrap(); + + pgdog.close().await; + failover.close().await; +} + #[tokio::test] #[serial] async fn test_maintenance_mode_transaction_behavior() { diff --git a/pgdog/src/admin/maintenance_mode.rs b/pgdog/src/admin/maintenance_mode.rs index 2c2c66079..fafeeac75 100644 --- a/pgdog/src/admin/maintenance_mode.rs +++ b/pgdog/src/admin/maintenance_mode.rs @@ -1,13 +1,21 @@ //! Turn maintenance mode on/off. +//! +//! Maintenance mode is special: it's completely independent from the config +//! and will hold true during config changes, e.g. when some databases disappear +//! and new ones are added. +//! +//! This is useful when changing the sharding config online, for example. +//! use crate::backend::maintenance_mode; use super::prelude::*; -/// Turn maintenance mode on/off. +/// Turn maintenance mode on/off, optionally for a single database. #[derive(Default)] pub struct MaintenanceMode { enable: bool, + database: Option, } #[async_trait] @@ -15,24 +23,33 @@ impl Command for MaintenanceMode { fn parse(sql: &str) -> Result { let parts = sql.split(" ").collect::>(); - match parts[..] { - ["maintenance", "on"] => Ok(Self { enable: true }), - ["maintenance", "off"] => Ok(Self { enable: false }), - _ => Err(Error::Syntax), - } + let (enable, database) = match parts[..] { + ["maintenance", "on"] => (true, None), + ["maintenance", "off"] => (false, None), + ["maintenance", "on", database] => (true, Some(database.to_string())), + ["maintenance", "off", database] => (false, Some(database.to_string())), + _ => return Err(Error::Syntax), + }; + + Ok(Self { enable, database }) } async fn execute(&self) -> Result, Error> { + let database = self.database.as_deref(); if self.enable { - maintenance_mode::start(); + maintenance_mode::start(database); } else { - maintenance_mode::stop(); + maintenance_mode::stop(database); } Ok(vec![]) } fn name(&self) -> String { - format!("MAINTENANCE {}", if self.enable { "ON" } else { "OFF" }) + let state = if self.enable { "ON" } else { "OFF" }; + match &self.database { + Some(database) => format!("MAINTENANCE {} {}", state, database), + None => format!("MAINTENANCE {}", state), + } } } diff --git a/pgdog/src/backend/maintenance_mode.rs b/pgdog/src/backend/maintenance_mode.rs index 34ffb4ced..d1e5fc312 100644 --- a/pgdog/src/backend/maintenance_mode.rs +++ b/pgdog/src/backend/maintenance_mode.rs @@ -1,45 +1,401 @@ -use std::sync::atomic::{AtomicBool, Ordering}; +//! Pause access to all/specific databases while we change the world. +//! +//! Maintenance mode is special: it's independent from the config +//! and will hold true during config changes, e.g. when some databases disappear, e.g., +//! replicas or shards are added/removed. +//! +//! This is useful when changing the sharding config online, for example. +//! +use std::{ + collections::HashMap, + future::{Future, IntoFuture}, + pin::Pin, + sync::Arc, +}; +use arc_swap::ArcSwap; use once_cell::sync::Lazy; -use tokio::sync::{futures::Notified, Notify}; +use parking_lot::Mutex; +use tokio::sync::broadcast; use tracing::warn; static MAINTENANCE_MODE: Lazy = Lazy::new(|| MaintenanceMode { - notify: Notify::new(), - on: AtomicBool::new(false), + state: ArcSwap::from_pointee(MaintenanceState::default()), + write_lock: Mutex::new(()), }); -pub(crate) fn waiter() -> Option> { - if !MAINTENANCE_MODE.on.load(Ordering::Relaxed) { - None - } else { - let notified = MAINTENANCE_MODE.notify.notified(); - if !MAINTENANCE_MODE.on.load(Ordering::Relaxed) { - None - } else { - Some(notified) - } +pub(crate) fn waiter(database: &str) -> Option { + MAINTENANCE_MODE.get_waiter(database) +} + +/// Future that resolves once a database leaves maintenance mode. +/// +/// Wraps the broadcast receiver so callers can simply `.await` it; it resolves +/// when the maintenance channel is closed (the sender is dropped by `stop`). +pub(crate) struct Waiter { + receiver: broadcast::Receiver<()>, + database: String, +} + +impl IntoFuture for Waiter { + type Output = (); + type IntoFuture = Pin + Send>>; + + fn into_future(mut self) -> Self::IntoFuture { + Box::pin(async move { + // Resolves when the channel is closed (sender dropped). + let _ = self.receiver.recv().await; + + // Re-check to avoid race between MAINTENANCE ON and MAINTENANCE ON. + if let Some(waiter) = MAINTENANCE_MODE.get_waiter(&self.database) { + let _ = waiter.await; + } + }) } } -pub fn start() { - MAINTENANCE_MODE.on.store(true, Ordering::Relaxed); - warn!("maintenance mode is on"); +pub fn start(database: Option<&str>) { + match database { + Some(database) => { + MAINTENANCE_MODE.add(database); + warn!("maintenance mode is on for database \"{}\"", database); + } + None => { + MAINTENANCE_MODE.add_all(); + warn!("maintenance mode is on for all databases"); + } + } } -pub fn stop() { - MAINTENANCE_MODE.on.store(false, Ordering::Relaxed); - MAINTENANCE_MODE.notify.notify_waiters(); - warn!("maintenance mode is off"); +pub fn stop(database: Option<&str>) { + match database { + Some(database) => { + MAINTENANCE_MODE.remove(database); + warn!("maintenance mode is off for database \"{}\"", database); + } + None => { + MAINTENANCE_MODE.remove_all(); + warn!("maintenance mode is off for all databases"); + } + } } #[cfg(test)] -pub fn is_on() -> bool { - MAINTENANCE_MODE.on.load(Ordering::Relaxed) +pub fn is_on(database: &str) -> bool { + MAINTENANCE_MODE.paused(database) } #[derive(Debug)] struct MaintenanceMode { - notify: Notify, - on: AtomicBool, + state: ArcSwap, + write_lock: Mutex<()>, +} + +#[derive(Clone, Debug, Default)] +struct MaintenanceState { + // Per-database maintenance mode. + databases: HashMap>, + // Global maintenance mode (all databases, current and future ones). + all: Option>, +} + +impl MaintenanceMode { + /// Check whether the given database is currently in maintenance mode. + #[cfg(test)] + #[inline] + fn paused(&self, database: &str) -> bool { + self.get_waiter(database).is_some() + } + + /// Get a [`Waiter`] that resolves once the database leaves maintenance + /// mode, or `None` if it isn't in maintenance mode right now. + /// + /// # Arguments + /// + /// * `database`: name of the database to wait for. + /// + fn get_waiter(&self, database: &str) -> Option { + let state = self.state.load(); + + if state.databases.is_empty() && state.all.is_none() { + return None; + } + + match state.databases.get(database) { + Some(sender) => Some(Waiter { + receiver: sender.subscribe(), + database: database.to_string(), + }), + None => state.all.as_ref().map(|sender| Waiter { + receiver: sender.subscribe(), + database: database.to_string(), + }), + } + } + + /// Put a single database into maintenance mode. + /// + /// # Arguments + /// + /// * `database`: name of the database to pause. + /// + fn add(&self, database: &str) { + let _guard = self.write_lock.lock(); + let state = self.state.load(); + let mut next = MaintenanceState::clone(&state); + + // Global maintenance covers individual databases. + if next.all.is_some() { + return; + } + + // Keep the existing channel if already paused, so current waiters + // stay valid. + next.databases + .entry(database.to_string()) + .or_insert_with(|| broadcast::channel(1).0); + + self.state.store(Arc::new(next)); + } + + /// Take a single database out of maintenance mode and wake its waiters by + /// dropping (closing) its channel. + /// + /// # Arguments + /// + /// * `database`: name of the database to resume. + /// + fn remove(&self, database: &str) { + let _guard = self.write_lock.lock(); + let state = self.state.load(); + let mut next = MaintenanceState::clone(&state); + + next.databases.remove(database); + + self.state.store(Arc::new(next)); + } + + /// Put every configured database into maintenance mode. + fn add_all(&self) { + let _guard = self.write_lock.lock(); + let state = self.state.load(); + + let mut next = MaintenanceState::clone(&state); + + if next.all.is_none() { + next.all = Some(broadcast::channel(1).0); + } + + self.state.store(Arc::new(next)); + } + + /// Take every database out of maintenance mode, including ones paused + /// individually, and wake their waiters by dropping (closing) the channels. + fn remove_all(&self) { + let _guard = self.write_lock.lock(); + let state = self.state.load(); + + let mut next = MaintenanceState::clone(&state); + + next.all = None; + next.databases.clear(); + + self.state.store(Arc::new(next)); + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::time::Duration; + use tokio::time::{sleep, timeout}; + + /// Fresh, isolated instance so tests don't share the global singleton. + fn maintenance() -> MaintenanceMode { + MaintenanceMode { + state: ArcSwap::from_pointee(MaintenanceState::default()), + write_lock: Mutex::new(()), + } + } + + #[test] + fn nothing_paused_by_default() { + let m = maintenance(); + assert!(!m.paused("anything")); + assert!(m.get_waiter("anything").is_none()); + } + + #[test] + fn pause_single_database() { + let m = maintenance(); + m.add("one"); + + assert!(m.paused("one")); + assert!(!m.paused("two")); + assert_eq!(m.state.load().databases.len(), 1); + + // Only the paused database gets a waiter. + assert!(m.get_waiter("one").is_some()); + assert!(m.get_waiter("two").is_none()); + + m.remove("one"); + assert!(!m.paused("one")); + assert_eq!(m.state.load().databases.len(), 0); + } + + #[test] + fn pause_is_idempotent() { + let m = maintenance(); + m.add("one"); + m.add("one"); + assert_eq!(m.state.load().databases.len(), 1); + + m.remove("one"); + assert_eq!(m.state.load().databases.len(), 0); + // Removing again is a no-op. + m.remove("one"); + assert_eq!(m.state.load().databases.len(), 0); + } + + #[tokio::test] + async fn pause_all_databases() { + // `add_all` pauses every database in the loaded config ("pgdog"). + crate::config::load_test(); + let m = maintenance(); + m.add_all(); + + assert!(m.paused("pgdog")); + assert!(m.get_waiter("pgdog").is_some()); + + m.remove_all(); + assert!(!m.paused("pgdog")); + } + + #[tokio::test] + async fn remove_all_clears_everything() { + // `remove_all` clears the whole set, including databases paused + // individually. + crate::config::load_test(); + let m = maintenance(); + m.add("other"); + m.add_all(); // pauses the configured "pgdog" + + assert!(m.paused("pgdog")); + assert!(m.paused("other")); + + m.remove_all(); + assert!(!m.paused("pgdog")); + assert!(!m.paused("other")); + } + + #[tokio::test] + async fn waiter_pending_until_resumed() { + let m = maintenance(); + m.add("one"); + let waiter = m.get_waiter("one").expect("database is paused"); + + // While paused, the waiter does not resolve. + let pending = timeout(Duration::from_millis(100), waiter.into_future()).await; + assert!(pending.is_err(), "waiter should still be pending"); + } + + #[tokio::test] + async fn waiter_resolves_on_remove() { + let m = maintenance(); + m.add("one"); + let waiter = m.get_waiter("one").expect("database is paused"); + + m.remove("one"); + + timeout(Duration::from_secs(1), waiter.into_future()) + .await + .expect("waiter should resolve once the database is resumed"); + } + + #[tokio::test] + async fn waiter_resolves_on_remove_all() { + crate::config::load_test(); + let m = maintenance(); + m.add_all(); + let waiter = m + .get_waiter("pgdog") + .expect("configured database is paused"); + + m.remove_all(); + + timeout(Duration::from_secs(1), waiter.into_future()) + .await + .expect("waiter should resolve once maintenance is lifted"); + } + + #[tokio::test] + async fn individual_maintenance_is_ignored_while_all_is_on() { + let m = maintenance(); + m.add_all(); + m.add("db"); // no-op: `all` already covers "db" + + // "db" is paused, but only through the global `all` channel. + assert!(m.state.load().databases.is_empty()); + let waiter = m + .get_waiter("db") + .expect("db is paused under all-maintenance"); + + // Resuming "db" individually must NOT release it while `all` is on. + m.remove("db"); + let pending = timeout(Duration::from_millis(100), waiter.into_future()).await; + assert!( + pending.is_err(), + "db waiter must stay blocked while all-maintenance is on" + ); + } + + #[tokio::test] + async fn waiter_survives_re_pause() { + // Re-pausing an already-paused database must keep the existing channel, + // so a waiter created before the second `add` still resolves on remove. + let m = maintenance(); + m.add("one"); + let waiter = m.get_waiter("one").expect("database is paused"); + m.add("one"); // must not replace the channel + + m.remove("one"); + + timeout(Duration::from_secs(1), waiter.into_future()) + .await + .expect("waiter should resolve even after a re-pause"); + } + + #[tokio::test] + async fn individual_resume_does_not_release_under_global() { + // Reverse ordering: a database is paused individually first, then global + // maintenance is turned on. Resuming the database individually must not + // release its waiter while global maintenance is still on — the waiter + // re-checks and re-parks on the global channel. + // + // Drives the global instance because the waiter's re-check consults it. + MAINTENANCE_MODE.remove_all(); // clean slate + + MAINTENANCE_MODE.add("db"); + let waiter = MAINTENANCE_MODE + .get_waiter("db") + .expect("db is paused individually"); + MAINTENANCE_MODE.add_all(); + + let handle = tokio::spawn(async move { waiter.await }); + + // Individual resume must not wake it while global maintenance is on. + MAINTENANCE_MODE.remove("db"); + sleep(Duration::from_millis(50)).await; + assert!( + !handle.is_finished(), + "waiter must stay blocked while global maintenance is on" + ); + + // Lifting global maintenance finally releases it. + MAINTENANCE_MODE.remove_all(); + timeout(Duration::from_secs(1), handle) + .await + .expect("waiter should resolve once global maintenance is lifted") + .unwrap(); + } } diff --git a/pgdog/src/backend/replication/logical/orchestrator.rs b/pgdog/src/backend/replication/logical/orchestrator.rs index 0630bad35..4079b1dbf 100644 --- a/pgdog/src/backend/replication/logical/orchestrator.rs +++ b/pgdog/src/backend/replication/logical/orchestrator.rs @@ -312,7 +312,7 @@ impl ReplicationWaiter { ); // Pause traffic. - maintenance_mode::start(); + maintenance_mode::start(None); // Cancel any running queries. cancel_all(&self.orchestrator.source.identifier().database).await?; @@ -408,7 +408,7 @@ impl ReplicationWaiter { match cutover_reason { CutoverAction::Go(CutoverReason::Timeout) => { if cutover_timeout_action == CutoverTimeoutAction::Abort { - maintenance_mode::stop(); + maintenance_mode::stop(None); warn!("[cutover] abort timeout reached, resuming traffic"); return Err(Error::AbortTimeout); } else { @@ -474,7 +474,7 @@ impl ReplicationWaiter { info!("[cutover] complete, resuming traffic"); // Point traffic to the other database and resume. - maintenance_mode::stop(); + maintenance_mode::stop(None); cutover_state(CutoverState::Complete); @@ -487,7 +487,7 @@ macro_rules! ok_or_abort { match $expr { Ok(res) => res, Err(err) => { - maintenance_mode::stop(); + maintenance_mode::stop(None); cutover_state(CutoverState::Abort { error: err.to_string(), }); @@ -534,8 +534,8 @@ mod tests { #[tokio::test] async fn test_wait_for_replication_exits_when_lag_below_threshold() { // Ensure maintenance mode is off at start - maintenance_mode::stop(); - assert!(!maintenance_mode::is_on()); + maintenance_mode::stop(None); + assert!(!maintenance_mode::is_on("")); // Will return true because all databases are paused. let mut config = ConfigAndUsers::default(); config.config.general.cutover_traffic_stop_threshold = 1000; @@ -557,11 +557,11 @@ mod tests { assert!(result.is_ok()); // Maintenance mode should be on after wait_for_replication - assert!(maintenance_mode::is_on()); + assert!(maintenance_mode::is_on("")); // Clean up maintenance mode - maintenance_mode::stop(); - assert!(!maintenance_mode::is_on()); + maintenance_mode::stop(None); + assert!(!maintenance_mode::is_on("")); } #[tokio::test] diff --git a/pgdog/src/frontend/client/mod.rs b/pgdog/src/frontend/client/mod.rs index 92a8e3ab5..47601c767 100644 --- a/pgdog/src/frontend/client/mod.rs +++ b/pgdog/src/frontend/client/mod.rs @@ -1,4 +1,7 @@ -//! Frontend client. +//! PostgreSQL client. +//! +//! Entrypoint for client/server interactions. +//! use std::net::SocketAddr; use std::sync::Arc; @@ -34,71 +37,74 @@ use crate::util::user_database_from_params; pub mod query_engine; pub mod sticky; pub mod timeouts; +pub mod transaction_type; pub(crate) use sticky::Sticky; +pub use transaction_type::TransactionType; -/// Frontend client. +/// PostgreSQL client. +/// +/// It thinks it's talking to a real Postgres server, but actually it's talking to PgDog :-). +/// #[derive(Debug)] pub struct Client { + // Client IP. addr: SocketAddr, + // Client socket. stream: Stream, + // Client unique identifier. Randomly generated + // for each client. id: BackendKeyData, - #[allow(dead_code)] - connect_params: Parameters, + // Client startup parameters. Keeps track of any parameters + // the client changes at runtime with `SET` as well. params: Parameters, + // Process-global communication primitives used for clients + // to talk to each other, e.g. to track their own state. comms: ClientComms, + // Client is connected to the admin database. admin: bool, + // Client is streaming data via replication, and not running + // regular queries. We skip all the fancy stuff here, i.e., + // no query parsing, routing, etc. + // + // Don't expect sharding to work if this is what the client is doing. streaming: bool, + // Client prepared statements cache. prepared_statements: PreparedStatements, + // Client transaction state. transaction: Option, + // Current timeouts to use for client/server communication. + // These change based on client state, e.g. if client is running query, + // the `query_timeout` is active, and if the client is idle, the `client_idle_timeout` is. timeouts: Timeouts, + // Stateful buffer containing the current whole client request. + // This can be a query or just a `Parse` and `Flush`, but in either case, the client + // will expect a response immediately and we need to handle it. client_request: ClientRequest, + // Raw buffer of messages the client sent. We keep them here to avoid memory allocations + // down the line (using [`bytes::Bytes`]). stream_buffer: MessageBuffer, + // Settings that override query routing behavior, e.g., client wants to talk + // to replicas only. sticky: Sticky, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum TransactionType { - ReadOnly, - #[default] - ReadWrite, - ErrorReadWrite, - ErrorReadOnly, -} - -impl TransactionType { - pub fn read_only(&self) -> bool { - matches!(self, Self::ReadOnly) - } - - pub fn write(&self) -> bool { - !self.read_only() - } - - pub fn error(&self) -> bool { - matches!(self, Self::ErrorReadWrite | Self::ErrorReadOnly) - } -} - -impl MemoryUsage for Client { - #[inline] - fn memory_usage(&self) -> usize { - std::mem::size_of::() - + std::mem::size_of::() - + std::mem::size_of::() - + self.connect_params.memory_usage() - + self.params.memory_usage() - + std::mem::size_of::() - + std::mem::size_of::() * 5 - + self.prepared_statements.memory_used() - + std::mem::size_of::() - + self.stream_buffer.capacity() - + self.client_request.memory_usage() - } + /// Client database. + database: String, } impl Client { - /// Create new frontend client from the given TCP stream. + /// Create new frontend client from the a TCP socket. + /// + /// The client already sent a valid Startup message and negotiated TLS. + /// + /// # Parameters + /// + /// - `stream`: TCP stream. + /// - `params`: Client parameters extracted from the [`crate::net::Startup`] message. + /// - `addr`: TCP IP. + /// - `config`: Currently loaded `pgdog.toml` and `users.toml`. + /// - `protocol_version`: The version of the PostgreSQL protocol used by the client. This is typically 3.0, but can be 3.2 + /// for more modern clients. + /// pub async fn spawn( stream: Stream, params: Parameters, @@ -372,7 +378,7 @@ impl Client { client_request: ClientRequest::default(), stream_buffer: MessageBuffer::new(config.config.memory.message_buffer), sticky: Sticky::from_params(¶ms), - connect_params: params, + database: database.to_string(), })) } @@ -396,7 +402,6 @@ impl Client { comms: ClientComms::new(&id), streaming: false, prepared_statements, - connect_params: connect_params.clone(), admin: false, transaction: None, timeouts: Timeouts::from_config(&config().config.general), @@ -404,6 +409,7 @@ impl Client { stream_buffer: MessageBuffer::new(4096), sticky: Sticky::from_params(&connect_params), params: connect_params, + database: "pgdog".to_string(), } } @@ -512,7 +518,7 @@ impl Client { async fn client_messages(&mut self, query_engine: &mut QueryEngine) -> Result<(), Error> { // Check maintenance mode. if !self.in_transaction() && !self.admin { - if let Some(waiter) = maintenance_mode::waiter() { + if let Some(waiter) = maintenance_mode::waiter(&self.database) { let state = query_engine.get_state(); query_engine.set_state(State::Waiting); waiter.await; @@ -657,6 +663,22 @@ impl Client { } } +impl MemoryUsage for Client { + #[inline] + fn memory_usage(&self) -> usize { + std::mem::size_of::() + + std::mem::size_of::() + + std::mem::size_of::() + + self.params.memory_usage() + + std::mem::size_of::() + + std::mem::size_of::() * 5 + + self.prepared_statements.memory_used() + + std::mem::size_of::() + + self.stream_buffer.capacity() + + self.client_request.memory_usage() + } +} + #[cfg(test)] pub mod test; diff --git a/pgdog/src/frontend/client/transaction_type.rs b/pgdog/src/frontend/client/transaction_type.rs new file mode 100644 index 000000000..bc98f6435 --- /dev/null +++ b/pgdog/src/frontend/client/transaction_type.rs @@ -0,0 +1,22 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum TransactionType { + ReadOnly, + #[default] + ReadWrite, + ErrorReadWrite, + ErrorReadOnly, +} + +impl TransactionType { + pub fn read_only(&self) -> bool { + matches!(self, Self::ReadOnly) + } + + pub fn write(&self) -> bool { + !self.read_only() + } + + pub fn error(&self) -> bool { + matches!(self, Self::ErrorReadWrite | Self::ErrorReadOnly) + } +} diff --git a/pgdog/src/net/parameter.rs b/pgdog/src/net/parameter.rs index 1502d0397..dc85eee2f 100644 --- a/pgdog/src/net/parameter.rs +++ b/pgdog/src/net/parameter.rs @@ -193,19 +193,6 @@ impl MemoryUsage for Parameters { } } -impl From> for Parameters { - fn from(value: BTreeMap) -> Self { - let hash = Self::compute_hash(&value); - Self { - params: value, - hash, - transaction_params: BTreeMap::new(), - transaction_local_params: BTreeMap::new(), - reset_params: BTreeMap::new(), - } - } -} - impl Parameters { /// Lowercase all param names. pub fn insert( @@ -345,12 +332,20 @@ impl Parameters { } pub fn tracked(&self) -> Parameters { - self.params + let params = self + .params .iter() .filter(|(k, _)| !UNTRACKED_PARAMS.contains(k)) .map(|(k, v)| (k.clone(), v.clone())) - .collect::>() - .into() + .collect::>(); + + let hash = Self::compute_hash(¶ms); + + Self { + params, + hash, + ..Default::default() + } } /// Merge params from self into other, generating the queries diff --git a/pgdog/src/util.rs b/pgdog/src/util.rs index 51d5d0ffe..962ed1c56 100644 --- a/pgdog/src/util.rs +++ b/pgdog/src/util.rs @@ -192,6 +192,17 @@ pub fn format_bytes(bytes: u64) -> String { } /// Get user and database parameters. +/// +/// These parameters are standard and defined by the Postgres protocol. +/// +/// # Arguments +/// +/// - `params`: Client parameters extracted from the [`crate::net::Startup`] message. +/// +/// # Return +/// +/// Tuple of (user, database). +/// pub fn user_database_from_params(params: &Parameters) -> (&str, &str) { let user = params.get_default("user", "postgres"); let database = params.get_default("database", user);