Skip to content

Commit fc39fd3

Browse files
authored
Allow users to customize connections created by the pool (#89, fixes #88)
1 parent a89e062 commit fc39fd3

File tree

4 files changed

+99
-3
lines changed

4 files changed

+99
-3
lines changed

bb8/src/api.rs

+38
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ pub struct Builder<M: ManageConnection> {
8484
pub(crate) error_sink: Box<dyn ErrorSink<M::Error>>,
8585
/// The time interval used to wake up and reap connections.
8686
pub(crate) reaper_rate: Duration,
87+
/// User-supplied trait object responsible for initializing connections
88+
pub(crate) connection_customizer: Box<dyn CustomizeConnection<M::Connection, M::Error>>,
8789
_p: PhantomData<M>,
8890
}
8991

@@ -98,6 +100,7 @@ impl<M: ManageConnection> Default for Builder<M> {
98100
connection_timeout: Duration::from_secs(30),
99101
error_sink: Box::new(NopErrorSink),
100102
reaper_rate: Duration::from_secs(30),
103+
connection_customizer: Box::new(NopConnectionCustomizer {}),
101104
_p: PhantomData,
102105
}
103106
}
@@ -204,6 +207,18 @@ impl<M: ManageConnection> Builder<M> {
204207
self
205208
}
206209

210+
/// Set the connection customizer which will be used to initialize
211+
/// connections created by the pool.
212+
///
213+
/// Defaults to `NopConnectionCustomizer`.
214+
pub fn connection_customizer(
215+
mut self,
216+
connection_customizer: Box<dyn CustomizeConnection<M::Connection, M::Error>>,
217+
) -> Builder<M> {
218+
self.connection_customizer = connection_customizer;
219+
self
220+
}
221+
207222
fn build_inner(self, manager: M) -> Pool<M> {
208223
if let Some(min_idle) = self.min_idle {
209224
assert!(
@@ -253,6 +268,29 @@ pub trait ManageConnection: Sized + Send + Sync + 'static {
253268
fn has_broken(&self, conn: &mut Self::Connection) -> bool;
254269
}
255270

271+
/// A trait which provides functionality to initialize a connection
272+
#[async_trait]
273+
pub trait CustomizeConnection<C: Send + 'static, E: 'static>:
274+
std::fmt::Debug + Send + Sync + 'static
275+
{
276+
/// Called with connections immediately after they are returned from
277+
/// `ManageConnection::connect`.
278+
///
279+
/// The default implementation simply returns `Ok(())`.
280+
///
281+
/// # Errors
282+
///
283+
/// If this method returns an error, the connection will be discarded.
284+
#[allow(unused_variables)]
285+
async fn on_acquire(&self, connection: &mut C) -> Result<(), E> {
286+
Ok(())
287+
}
288+
}
289+
290+
#[derive(Copy, Clone, Debug)]
291+
struct NopConnectionCustomizer;
292+
impl<C: Send + 'static, E: 'static> CustomizeConnection<C, E> for NopConnectionCustomizer {}
293+
256294
/// A smart pointer wrapping a connection.
257295
pub struct PooledConnection<'a, M>
258296
where

bb8/src/inner.rs

+19-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::time::{Duration, Instant};
66

77
use futures_channel::oneshot;
88
use futures_util::stream::{FuturesUnordered, StreamExt};
9+
use futures_util::TryFutureExt;
910
use parking_lot::Mutex;
1011
use tokio::spawn;
1112
use tokio::time::{interval_at, sleep, timeout, Interval};
@@ -126,7 +127,9 @@ where
126127
}
127128

128129
pub(crate) async fn connect(&self) -> Result<M::Connection, M::Error> {
129-
self.inner.manager.connect().await
130+
let mut conn = self.inner.manager.connect().await?;
131+
self.on_acquire_connection(&mut conn).await?;
132+
Ok(conn)
130133
}
131134

132135
/// Return connection back in to the pool
@@ -174,7 +177,13 @@ where
174177
let start = Instant::now();
175178
let mut delay = Duration::from_secs(0);
176179
loop {
177-
match shared.manager.connect().await {
180+
let conn = shared
181+
.manager
182+
.connect()
183+
.and_then(|mut c| async { self.on_acquire_connection(&mut c).await.map(|_| c) })
184+
.await;
185+
186+
match conn {
178187
Ok(conn) => {
179188
let conn = Conn::new(conn);
180189
shared.internals.lock().put(conn, Some(approval));
@@ -194,6 +203,14 @@ where
194203
}
195204
}
196205
}
206+
207+
async fn on_acquire_connection(&self, conn: &mut M::Connection) -> Result<(), M::Error> {
208+
self.inner
209+
.statics
210+
.connection_customizer
211+
.on_acquire(conn)
212+
.await
213+
}
197214
}
198215

199216
impl<M> Clone for PoolInner<M>

bb8/src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535

3636
mod api;
3737
pub use api::{
38-
Builder, ErrorSink, ManageConnection, NopErrorSink, Pool, PooledConnection, RunError, State,
38+
Builder, CustomizeConnection, ErrorSink, ManageConnection, NopErrorSink, Pool,
39+
PooledConnection, RunError, State,
3940
};
4041

4142
mod inner;

bb8/tests/test.rs

+40
Original file line numberDiff line numberDiff line change
@@ -746,3 +746,43 @@ async fn test_guard() {
746746
tx4.send(()).unwrap();
747747
tx6.send(()).unwrap();
748748
}
749+
750+
#[tokio::test]
751+
async fn test_customize_connection_acquire() {
752+
#[derive(Debug, Default)]
753+
struct Connection {
754+
custom_field: usize,
755+
};
756+
757+
#[derive(Debug, Default)]
758+
struct CountingCustomizer {
759+
count: std::sync::atomic::AtomicUsize,
760+
}
761+
762+
#[async_trait]
763+
impl<E: 'static> CustomizeConnection<Connection, E> for CountingCustomizer {
764+
async fn on_acquire(&self, connection: &mut Connection) -> Result<(), E> {
765+
connection.custom_field = 1 + self.count.fetch_add(1, Ordering::SeqCst);
766+
Ok(())
767+
}
768+
}
769+
770+
let pool = Pool::builder()
771+
.max_size(2)
772+
.connection_customizer(Box::new(CountingCustomizer::default()))
773+
.build(OkManager::<Connection>::new())
774+
.await
775+
.unwrap();
776+
777+
// Each connection gets customized
778+
{
779+
let connection_1 = pool.get().await.unwrap();
780+
assert_eq!(connection_1.custom_field, 1);
781+
let connection_2 = pool.get().await.unwrap();
782+
assert_eq!(connection_2.custom_field, 2);
783+
}
784+
785+
// Connections don't get customized again on re-use
786+
let connection_1_or_2 = pool.get().await.unwrap();
787+
assert!(connection_1_or_2.custom_field == 1 || connection_1_or_2.custom_field == 2);
788+
}

0 commit comments

Comments
 (0)