diff --git a/bb8/Cargo.toml b/bb8/Cargo.toml index ce370ef..45b1aa3 100644 --- a/bb8/Cargo.toml +++ b/bb8/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bb8" -version = "0.8.2" +version = "0.8.3" edition = "2021" rust-version = "1.63" description = "Full-featured async (tokio-based) connection pool (like r2d2)" diff --git a/bb8/src/api.rs b/bb8/src/api.rs index 9ecad68..8b99798 100644 --- a/bb8/src/api.rs +++ b/bb8/src/api.rs @@ -63,6 +63,7 @@ impl Pool { Ok(PooledConnection { conn: self.get().await?.take(), pool: Cow::Owned(self.inner.clone()), + state: ConnectionState::Present, }) } @@ -372,6 +373,7 @@ where { pool: Cow<'a, PoolInner>, conn: Option>, + pub(crate) state: ConnectionState, } impl<'a, M> PooledConnection<'a, M> @@ -382,14 +384,12 @@ where Self { pool: Cow::Borrowed(pool), conn: Some(conn), + state: ConnectionState::Present, } } - pub(crate) fn drop_invalid(mut self) { - let _ = self.conn.take(); - } - - pub(crate) fn take(&mut self) -> Option> { + pub(crate) fn take(mut self) -> Option> { + self.state = ConnectionState::Extracted; self.conn.take() } } @@ -429,10 +429,24 @@ where M: ManageConnection, { fn drop(&mut self) { - self.pool.as_ref().put_back(self.conn.take()); + if let ConnectionState::Extracted = self.state { + return; + } + + debug_assert!(self.conn.is_some(), "incorrect state {:?}", self.state); + if let Some(conn) = self.conn.take() { + self.pool.as_ref().put_back(conn, self.state); + } } } +#[derive(Debug, Clone, Copy)] +pub(crate) enum ConnectionState { + Present, + Extracted, + Invalid, +} + /// bb8's error type. #[derive(Debug, Clone, PartialEq, Eq)] pub enum RunError { diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index 7f22516..8ba6f53 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -9,7 +9,7 @@ use futures_util::TryFutureExt; use tokio::spawn; use tokio::time::{interval_at, sleep, timeout, Interval}; -use crate::api::{Builder, ManageConnection, PooledConnection, RunError}; +use crate::api::{Builder, ConnectionState, ManageConnection, PooledConnection, RunError}; use crate::internals::{Approval, ApprovalIter, Conn, SharedPool, State}; pub(crate) struct PoolInner @@ -89,6 +89,10 @@ where loop { let (conn, approvals) = self.inner.pop(); self.spawn_replenishing_approvals(approvals); + + // Cancellation safety: make sure to wrap the connection in a `PooledConnection` + // before allowing the code to hit an `await`, so we don't lose the connection. + let mut conn = match conn { Some(conn) => PooledConnection::new(self, conn), None => { @@ -105,7 +109,7 @@ where Ok(()) => return Ok(conn), Err(e) => { self.inner.forward_error(e); - conn.drop_invalid(); + conn.state = ConnectionState::Invalid; continue; } } @@ -125,19 +129,16 @@ where } /// Return connection back in to the pool - pub(crate) fn put_back(&self, conn: Option>) { - let conn = conn.and_then(|mut conn| { - if !self.inner.manager.has_broken(&mut conn.conn) { - Some(conn) - } else { - None - } - }); + pub(crate) fn put_back(&self, mut conn: Conn, state: ConnectionState) { + debug_assert!( + !matches!(state, ConnectionState::Extracted), + "handled in caller" + ); let mut locked = self.inner.internals.lock(); - match conn { - Some(conn) => locked.put(conn, None, self.inner.clone()), - None => { + match (state, self.inner.manager.has_broken(&mut conn.conn)) { + (ConnectionState::Present, false) => locked.put(conn, None, self.inner.clone()), + (_, _) => { let approvals = locked.dropped(1, &self.inner.statics); self.spawn_replenishing_approvals(approvals); } diff --git a/bb8/src/internals.rs b/bb8/src/internals.rs index 9794453..ed153f0 100644 --- a/bb8/src/internals.rs +++ b/bb8/src/internals.rs @@ -76,8 +76,16 @@ where pool: Arc>, ) { if approval.is_some() { - self.pending_conns -= 1; - self.num_conns += 1; + #[cfg(debug_assertions)] + { + self.pending_conns -= 1; + self.num_conns += 1; + } + #[cfg(not(debug_assertions))] + { + self.pending_conns = self.pending_conns.saturating_sub(1); + self.num_conns = self.num_conns.saturating_add(1); + } } // Queue it in the idle queue @@ -91,35 +99,39 @@ where } pub(crate) fn connect_failed(&mut self, _: Approval) { - self.pending_conns -= 1; + #[cfg(debug_assertions)] + { + self.pending_conns -= 1; + } + #[cfg(not(debug_assertions))] + { + self.pending_conns = self.pending_conns.saturating_sub(1); + } } pub(crate) fn dropped(&mut self, num: u32, config: &Builder) -> ApprovalIter { - self.num_conns -= num; + #[cfg(debug_assertions)] + { + self.num_conns -= num; + } + #[cfg(not(debug_assertions))] + { + self.num_conns = self.num_conns.saturating_sub(num); + } + self.wanted(config) } pub(crate) fn wanted(&mut self, config: &Builder) -> ApprovalIter { let available = self.conns.len() as u32 + self.pending_conns; let min_idle = config.min_idle.unwrap_or(0); - let wanted = if available < min_idle { - min_idle - available - } else { - 0 - }; - + let wanted = min_idle.saturating_sub(available); self.approvals(config, wanted) } fn approvals(&mut self, config: &Builder, num: u32) -> ApprovalIter { let current = self.num_conns + self.pending_conns; - let allowed = if current < config.max_size { - config.max_size - current - } else { - 0 - }; - - let num = min(num, allowed); + let num = min(num, config.max_size.saturating_sub(current)); self.pending_conns += num; ApprovalIter { num: num as usize } }