diff --git a/bb8/Cargo.toml b/bb8/Cargo.toml index 16ce72a..fba5dbc 100644 --- a/bb8/Cargo.toml +++ b/bb8/Cargo.toml @@ -14,7 +14,7 @@ async-trait = "0.1" futures-channel = "0.3.2" futures-util = { version = "0.3.2", default-features = false, features = ["channel"] } parking_lot = { version = "0.12", optional = true } -tokio = { version = "1.0", features = ["rt", "time"] } +tokio = { version = "1.0", features = ["rt", "time", "sync"] } [dev-dependencies] tokio = { version = "1.0", features = ["macros"] } diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index ed13264..6915ba1 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -4,7 +4,6 @@ use std::future::Future; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; -use futures_channel::oneshot; use futures_util::stream::{FuturesUnordered, StreamExt}; use futures_util::TryFutureExt; use tokio::spawn; @@ -103,46 +102,84 @@ where &'a self, make_pooled_conn: F, ) -> Result, RunError> + where + F: Fn(&'a Self, Conn) -> PooledConnection<'b, M>, + { + match timeout( + self.inner.statics.connection_timeout, + self.make_pooled_internal(make_pooled_conn), + ) + .await + { + Ok(result) => result, + _ => Err(RunError::TimedOut), + } + } + + async fn make_pooled_internal<'a, 'b, F>( + &'a self, + make_pooled_conn: F, + ) -> Result, RunError> where F: Fn(&'a Self, Conn) -> PooledConnection<'b, M>, { loop { - let mut conn = { - let mut locked = self.inner.internals.lock(); - match locked.pop(&self.inner.statics) { - Some((conn, approvals)) => { - self.spawn_replenishing_approvals(approvals); - make_pooled_conn(self, conn) - } - None => break, + loop { + // Get in the same queue as everyone else for a connection. + let waiter = { + let locked = self.inner.internals.lock(); + locked.request_connection() + }; + + // A connection is availble, the waiter has a chance to get it. + if let Some(waiter) = waiter { + waiter.notified().await; } - }; - if !self.inner.statics.test_on_check_out { - return Ok(conn); - } + // Try to get the connection if it's still availble. + let mut conn = { + let mut locked = self.inner.internals.lock(); - match self.inner.manager.is_valid(&mut conn).await { - Ok(()) => return Ok(conn), - Err(e) => { - self.inner.forward_error(e); - conn.drop_invalid(); - continue; + match locked.pop(&self.inner.statics) { + Some((conn, approvals)) => { + self.spawn_replenishing_approvals(approvals); + make_pooled_conn(self, conn) + } + + // All open connections are gone, go make a new one and wait. + None => break, + } + }; + + if !self.inner.statics.test_on_check_out { + return Ok(conn); + } + + match self.inner.manager.is_valid(&mut conn).await { + Ok(()) => return Ok(conn), + Err(e) => { + self.inner.statics.error_sink.sink(e); + conn.drop_invalid(); + continue; + } } } - } - let (tx, rx) = oneshot::channel(); - { - let mut locked = self.inner.internals.lock(); - let approvals = locked.push_waiter(tx, &self.inner.statics); - self.spawn_replenishing_approvals(approvals); - }; + // No connection is available, wait for one to be created for us. + let waiter = { + let mut locked = self.inner.internals.lock(); + let (waiter, approvals) = locked.push_waiter(&self.inner.statics); + self.spawn_replenishing_approvals(approvals); + waiter + }; - match timeout(self.inner.statics.connection_timeout, rx).await { - Ok(Ok(Ok(mut guard))) => Ok(make_pooled_conn(self, guard.extract())), - Ok(Ok(Err(e))) => Err(RunError::User(e)), - _ => Err(RunError::TimedOut), + waiter.notified().await; + + // Did we get it? No? Let's keep waiting. + match self.inner.internals.lock().pop(&self.inner.statics) { + Some(conn) => return Ok(make_pooled_conn(self, conn.0)), + None => continue, + }; } } diff --git a/bb8/src/internals.rs b/bb8/src/internals.rs index 0ed52ac..07d9d2d 100644 --- a/bb8/src/internals.rs +++ b/bb8/src/internals.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::time::Instant; use crate::{api::QueueStrategy, lock::Mutex}; -use futures_channel::oneshot; +use tokio::sync::Notify; use crate::api::{Builder, ManageConnection}; use std::collections::VecDeque; @@ -31,15 +31,7 @@ where } } - pub(crate) fn forward_error(&self, mut err: M::Error) { - let mut locked = self.internals.lock(); - while let Some(waiter) = locked.waiters.pop_front() { - match waiter.send(Err(err)) { - Ok(_) => return, - Err(Err(e)) => err = e, - Err(Ok(_)) => unreachable!(), - } - } + pub(crate) fn forward_error(&self, err: M::Error) { self.statics.error_sink.sink(err); } } @@ -50,7 +42,7 @@ pub(crate) struct PoolInternals where M: ManageConnection, { - waiters: VecDeque, M::Error>>>, + notify: Arc, conns: VecDeque>, num_conns: u32, pending_conns: u32, @@ -82,24 +74,14 @@ where let queue_strategy = pool.statics.queue_strategy; - let mut guard = InternalsGuard::new(conn, pool); - while let Some(waiter) = self.waiters.pop_front() { - // This connection is no longer idle, send it back out - match waiter.send(Ok(guard)) { - Ok(()) => return, - Err(Ok(g)) => { - guard = g; - } - Err(Err(_)) => unreachable!(), - } - } - // Queue it in the idle queue - let conn = IdleConn::from(guard.conn.take().unwrap()); + let conn = IdleConn::from(conn); match queue_strategy { QueueStrategy::Fifo => self.conns.push_back(conn), QueueStrategy::Lifo => self.conns.push_front(conn), - } + }; + + self.notify.notify_one() } pub(crate) fn connect_failed(&mut self, _: Approval) { @@ -123,13 +105,22 @@ where self.approvals(config, wanted) } - pub(crate) fn push_waiter( - &mut self, - waiter: oneshot::Sender, M::Error>>, - config: &Builder, - ) -> ApprovalIter { - self.waiters.push_back(waiter); - self.approvals(config, 1) + pub(crate) fn push_waiter(&mut self, config: &Builder) -> (Arc, ApprovalIter) { + let notify = self.notify.clone(); + let approvals = self.approvals(config, 1); + + (notify, approvals) + } + + pub(crate) fn request_connection(&self) -> Option> { + let notify = self.notify.clone(); + + if !self.conns.is_empty() { + self.notify.notify_one(); + Some(notify) + } else { + None + } } fn approvals(&mut self, config: &Builder, num: u32) -> ApprovalIter { @@ -177,7 +168,7 @@ where { fn default() -> Self { Self { - waiters: VecDeque::new(), + notify: Arc::new(Notify::new()), conns: VecDeque::new(), num_conns: 0, pending_conns: 0, @@ -185,33 +176,6 @@ where } } -pub(crate) struct InternalsGuard { - conn: Option>, - pool: Arc>, -} - -impl InternalsGuard { - fn new(conn: Conn, pool: Arc>) -> Self { - Self { - conn: Some(conn), - pool, - } - } - - pub(crate) fn extract(&mut self) -> Conn { - self.conn.take().unwrap() // safe: can only be `None` after `Drop` - } -} - -impl Drop for InternalsGuard { - fn drop(&mut self) { - if let Some(conn) = self.conn.take() { - let mut locked = self.pool.internals.lock(); - locked.put(conn, None, self.pool.clone()); - } - } -} - #[must_use] pub(crate) struct ApprovalIter { num: usize, diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index 70710a2..43e233d 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -282,7 +282,7 @@ async fn test_lazy_initialization_failure_no_retry() { .build_unchecked(manager); let res = pool.get().await; - assert_eq!(res.unwrap_err(), RunError::User(Error)); + assert_eq!(res.unwrap_err(), RunError::TimedOut); } #[tokio::test] @@ -317,6 +317,55 @@ async fn test_get_timeout() { ready(r).await.unwrap(); } +#[tokio::test] +async fn test_lots_of_waiters() { + let pool = Pool::builder() + .max_size(3) + .connection_timeout(Duration::from_millis(5_000)) + .build(OkManager::::new()) + .await + .unwrap(); + + let mut waiters: Vec> = Vec::new(); + + for _ in 0..25000 { + let pool = pool.clone(); + let (tx, rx) = oneshot::channel(); + waiters.push(rx); + tokio::spawn(async move { + let _conn = pool.get().await.unwrap(); + tx.send(()).unwrap(); + }); + } + + let results = futures_util::future::join_all(&mut waiters).await; + + for result in results { + assert!(result.is_ok()); + } +} + +#[tokio::test] +async fn test_timeout_caller() { + let pool = Pool::builder() + .max_size(1) + .connection_timeout(Duration::from_millis(5_000)) + .build(OkManager::::new()) + .await + .unwrap(); + + let one = pool.get().await; + assert!(one.is_ok()); + + let res = tokio::time::timeout(Duration::from_millis(100), pool.get()).await; + assert!(res.is_err()); + + drop(one); + + let two = pool.get().await; + assert!(two.is_ok()); +} + #[tokio::test] async fn test_now_invalid() { static INVALID: AtomicBool = AtomicBool::new(false);