Skip to content

Commit a80460e

Browse files
committed
Resurrect on_release support from #89 by spawning
1 parent eadfb32 commit a80460e

File tree

4 files changed

+158
-13
lines changed

4 files changed

+158
-13
lines changed

bb8/src/api.rs

+12-7
Original file line numberDiff line numberDiff line change
@@ -273,13 +273,18 @@ pub trait CustomizeConnection<C: Send + 'static, E: 'static>:
273273
/// Called with connections immediately after they are returned from
274274
/// `ManageConnection::connect`.
275275
///
276-
/// The default implementation simply returns `Ok(())`.
277-
///
278-
/// # Errors
276+
/// The default implementation simply returns `Ok(())`. Any errors will be forwarded to the
277+
/// configured error sink.
278+
async fn on_acquire(&self, _connection: &mut C) -> Result<(), E> {
279+
Ok(())
280+
}
281+
282+
/// Called with connections before they're returned to the connection pool.
279283
///
280-
/// If this method returns an error, the connection will be discarded.
284+
/// The default implementation simply returns `Ok(())`. Any errors will be forwarded to the
285+
/// configured error sink.
281286
#[allow(unused_variables)]
282-
async fn on_acquire(&self, connection: &mut C) -> Result<(), E> {
287+
async fn on_release(&'_ self, _connection: &'_ mut C) -> Result<(), E> {
283288
Ok(())
284289
}
285290
}
@@ -304,8 +309,8 @@ where
304309
}
305310
}
306311

307-
pub(crate) fn drop_invalid(mut self) {
308-
let _ = self.conn.take();
312+
pub(crate) fn extract(mut self) -> Conn<M::Connection> {
313+
self.conn.take().unwrap()
309314
}
310315
}
311316

bb8/src/inner.rs

+23-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ where
102102
match self.inner.manager.is_valid(&mut conn).await {
103103
Ok(()) => return Ok(conn),
104104
Err(_) => {
105-
conn.drop_invalid();
105+
self.on_release_connection(conn.extract());
106+
// Once we've extracted the connection, the `Drop` impl for `PooledConnection`
107+
// will call `put_back(None)`, so we don't need to do anything else here.
106108
continue;
107109
}
108110
}
@@ -133,6 +135,7 @@ where
133135
if !self.inner.manager.has_broken(&mut conn.conn) {
134136
Some(conn)
135137
} else {
138+
self.on_release_connection(conn);
136139
None
137140
}
138141
});
@@ -147,6 +150,25 @@ where
147150
}
148151
}
149152

153+
fn on_release_connection(&self, mut conn: Conn<M::Connection>) {
154+
if self.inner.statics.connection_customizer.is_none() {
155+
return;
156+
}
157+
158+
let pool = self.inner.clone();
159+
spawn(async move {
160+
let customizer = match pool.statics.connection_customizer.as_ref() {
161+
Some(customizer) => customizer,
162+
None => return,
163+
};
164+
165+
let future = customizer.on_release(&mut conn.conn);
166+
if let Err(e) = future.await {
167+
pool.statics.error_sink.sink(e);
168+
}
169+
});
170+
}
171+
150172
/// Returns information about the current state of the pool.
151173
pub(crate) fn state(&self) -> State {
152174
self.inner.internals.lock().state()

bb8/src/internals.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,12 @@ where
168168
}
169169

170170
pub(crate) struct InternalsGuard<M: ManageConnection> {
171-
conn: Option<Conn<M::Connection>>,
172-
pool: Arc<SharedPool<M>>,
171+
pub(crate) conn: Option<Conn<M::Connection>>,
172+
pub(crate) pool: Arc<SharedPool<M>>,
173173
}
174174

175175
impl<M: ManageConnection> InternalsGuard<M> {
176-
fn new(conn: Conn<M::Connection>, pool: Arc<SharedPool<M>>) -> Self {
176+
pub(crate) fn new(conn: Conn<M::Connection>, pool: Arc<SharedPool<M>>) -> Self {
177177
Self {
178178
conn: Some(conn),
179179
pool,

bb8/tests/test.rs

+120-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::iter::FromIterator;
55
use std::marker::PhantomData;
66
use std::pin::Pin;
77
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8-
use std::sync::Mutex;
8+
use std::sync::{Arc, Mutex};
99
use std::task::Poll;
1010
use std::time::Duration;
1111
use std::{error, fmt, mem};
@@ -14,7 +14,7 @@ use async_trait::async_trait;
1414
use futures_channel::oneshot;
1515
use futures_util::future::{err, lazy, ok, pending, ready, try_join_all, FutureExt};
1616
use futures_util::stream::{FuturesUnordered, TryStreamExt};
17-
use tokio::time::timeout;
17+
use tokio::time::{sleep, timeout};
1818

1919
#[derive(Debug, PartialEq, Eq)]
2020
pub struct Error;
@@ -786,3 +786,121 @@ async fn test_customize_connection_acquire() {
786786
let connection_1_or_2 = pool.get().await.unwrap();
787787
assert!(connection_1_or_2.custom_field == 1 || connection_1_or_2.custom_field == 2);
788788
}
789+
790+
#[tokio::test]
791+
async fn test_customize_connection_release() {
792+
#[derive(Debug)]
793+
struct CountingCustomizer {
794+
num_conn_released: Arc<AtomicUsize>,
795+
}
796+
797+
impl CountingCustomizer {
798+
fn new(num_conn_released: Arc<AtomicUsize>) -> Self {
799+
Self { num_conn_released }
800+
}
801+
}
802+
803+
#[async_trait]
804+
impl<E: 'static> CustomizeConnection<FakeConnection, E> for CountingCustomizer {
805+
async fn on_release(&self, _connection: &mut FakeConnection) -> Result<(), E> {
806+
self.num_conn_released.fetch_add(1, Ordering::SeqCst);
807+
Ok(())
808+
}
809+
}
810+
811+
#[derive(Debug)]
812+
struct BreakableManager<C> {
813+
_c: PhantomData<C>,
814+
valid: Arc<AtomicBool>,
815+
broken: Arc<AtomicBool>,
816+
};
817+
818+
impl<C> BreakableManager<C> {
819+
fn new(valid: Arc<AtomicBool>, broken: Arc<AtomicBool>) -> Self {
820+
Self {
821+
valid,
822+
broken,
823+
_c: PhantomData,
824+
}
825+
}
826+
}
827+
828+
#[async_trait]
829+
impl<C> ManageConnection for BreakableManager<C>
830+
where
831+
C: Default + Send + Sync + 'static,
832+
{
833+
type Connection = C;
834+
type Error = Error;
835+
836+
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
837+
Ok(Default::default())
838+
}
839+
840+
async fn is_valid(
841+
&self,
842+
_conn: &mut PooledConnection<'_, Self>,
843+
) -> Result<(), Self::Error> {
844+
if self.valid.load(Ordering::SeqCst) {
845+
Ok(())
846+
} else {
847+
Err(Error)
848+
}
849+
}
850+
851+
fn has_broken(&self, _: &mut Self::Connection) -> bool {
852+
self.broken.load(Ordering::SeqCst)
853+
}
854+
}
855+
856+
let valid = Arc::new(AtomicBool::new(true));
857+
let broken = Arc::new(AtomicBool::new(false));
858+
let manager = BreakableManager::<FakeConnection>::new(valid.clone(), broken.clone());
859+
860+
let num_conn_released = Arc::new(AtomicUsize::new(0));
861+
let customizer = CountingCustomizer::new(num_conn_released.clone());
862+
863+
let pool = Pool::builder()
864+
.max_size(2)
865+
.connection_customizer(Box::new(customizer))
866+
.build(manager)
867+
.await
868+
.unwrap();
869+
870+
// Connections go in and out of the pool without being released
871+
{
872+
{
873+
let _connection_1 = pool.get().await.unwrap();
874+
let _connection_2 = pool.get().await.unwrap();
875+
assert_eq!(num_conn_released.load(Ordering::SeqCst), 0);
876+
}
877+
{
878+
let _connection_1 = pool.get().await.unwrap();
879+
let _connection_2 = pool.get().await.unwrap();
880+
assert_eq!(num_conn_released.load(Ordering::SeqCst), 0);
881+
}
882+
}
883+
884+
// Invalid connections get released
885+
{
886+
valid.store(false, Ordering::SeqCst);
887+
let _connection_1 = pool.get().await.unwrap();
888+
assert_eq!(num_conn_released.load(Ordering::SeqCst), 2);
889+
let _connection_2 = pool.get().await.unwrap();
890+
assert_eq!(num_conn_released.load(Ordering::SeqCst), 2);
891+
valid.store(true, Ordering::SeqCst);
892+
}
893+
894+
// Broken connections get released
895+
{
896+
num_conn_released.store(0, Ordering::SeqCst);
897+
broken.store(true, Ordering::SeqCst);
898+
{
899+
let _connection_1 = pool.get().await.unwrap();
900+
let _connection_2 = pool.get().await.unwrap();
901+
assert_eq!(num_conn_released.load(Ordering::SeqCst), 0);
902+
}
903+
sleep(Duration::from_millis(100)).await;
904+
assert_eq!(num_conn_released.load(Ordering::SeqCst), 2);
905+
}
906+
}

0 commit comments

Comments
 (0)