@@ -5,7 +5,7 @@ use std::iter::FromIterator;
5
5
use std:: marker:: PhantomData ;
6
6
use std:: pin:: Pin ;
7
7
use std:: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
8
- use std:: sync:: Mutex ;
8
+ use std:: sync:: { Arc , Mutex } ;
9
9
use std:: task:: Poll ;
10
10
use std:: time:: Duration ;
11
11
use std:: { error, fmt, mem} ;
@@ -14,7 +14,7 @@ use async_trait::async_trait;
14
14
use futures_channel:: oneshot;
15
15
use futures_util:: future:: { err, lazy, ok, pending, ready, try_join_all, FutureExt } ;
16
16
use futures_util:: stream:: { FuturesUnordered , TryStreamExt } ;
17
- use tokio:: time:: timeout;
17
+ use tokio:: time:: { sleep , timeout} ;
18
18
19
19
#[ derive( Debug , PartialEq , Eq ) ]
20
20
pub struct Error ;
@@ -786,3 +786,121 @@ async fn test_customize_connection_acquire() {
786
786
let connection_1_or_2 = pool. get ( ) . await . unwrap ( ) ;
787
787
assert ! ( connection_1_or_2. custom_field == 1 || connection_1_or_2. custom_field == 2 ) ;
788
788
}
789
+
790
+ #[ tokio:: test]
791
+ async fn test_customize_connection_release ( ) {
792
+ #[ derive( Debug ) ]
793
+ struct CountingCustomizer {
794
+ num_conn_released : Arc < AtomicUsize > ,
795
+ }
796
+
797
+ impl CountingCustomizer {
798
+ fn new ( num_conn_released : Arc < AtomicUsize > ) -> Self {
799
+ Self { num_conn_released }
800
+ }
801
+ }
802
+
803
+ #[ async_trait]
804
+ impl < E : ' static > CustomizeConnection < FakeConnection , E > for CountingCustomizer {
805
+ async fn on_release ( & self , _connection : & mut FakeConnection ) -> Result < ( ) , E > {
806
+ self . num_conn_released . fetch_add ( 1 , Ordering :: SeqCst ) ;
807
+ Ok ( ( ) )
808
+ }
809
+ }
810
+
811
+ #[ derive( Debug ) ]
812
+ struct BreakableManager < C > {
813
+ _c : PhantomData < C > ,
814
+ valid : Arc < AtomicBool > ,
815
+ broken : Arc < AtomicBool > ,
816
+ } ;
817
+
818
+ impl < C > BreakableManager < C > {
819
+ fn new ( valid : Arc < AtomicBool > , broken : Arc < AtomicBool > ) -> Self {
820
+ Self {
821
+ valid,
822
+ broken,
823
+ _c : PhantomData ,
824
+ }
825
+ }
826
+ }
827
+
828
+ #[ async_trait]
829
+ impl < C > ManageConnection for BreakableManager < C >
830
+ where
831
+ C : Default + Send + Sync + ' static ,
832
+ {
833
+ type Connection = C ;
834
+ type Error = Error ;
835
+
836
+ async fn connect ( & self ) -> Result < Self :: Connection , Self :: Error > {
837
+ Ok ( Default :: default ( ) )
838
+ }
839
+
840
+ async fn is_valid (
841
+ & self ,
842
+ _conn : & mut PooledConnection < ' _ , Self > ,
843
+ ) -> Result < ( ) , Self :: Error > {
844
+ if self . valid . load ( Ordering :: SeqCst ) {
845
+ Ok ( ( ) )
846
+ } else {
847
+ Err ( Error )
848
+ }
849
+ }
850
+
851
+ fn has_broken ( & self , _: & mut Self :: Connection ) -> bool {
852
+ self . broken . load ( Ordering :: SeqCst )
853
+ }
854
+ }
855
+
856
+ let valid = Arc :: new ( AtomicBool :: new ( true ) ) ;
857
+ let broken = Arc :: new ( AtomicBool :: new ( false ) ) ;
858
+ let manager = BreakableManager :: < FakeConnection > :: new ( valid. clone ( ) , broken. clone ( ) ) ;
859
+
860
+ let num_conn_released = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
861
+ let customizer = CountingCustomizer :: new ( num_conn_released. clone ( ) ) ;
862
+
863
+ let pool = Pool :: builder ( )
864
+ . max_size ( 2 )
865
+ . connection_customizer ( Box :: new ( customizer) )
866
+ . build ( manager)
867
+ . await
868
+ . unwrap ( ) ;
869
+
870
+ // Connections go in and out of the pool without being released
871
+ {
872
+ {
873
+ let _connection_1 = pool. get ( ) . await . unwrap ( ) ;
874
+ let _connection_2 = pool. get ( ) . await . unwrap ( ) ;
875
+ assert_eq ! ( num_conn_released. load( Ordering :: SeqCst ) , 0 ) ;
876
+ }
877
+ {
878
+ let _connection_1 = pool. get ( ) . await . unwrap ( ) ;
879
+ let _connection_2 = pool. get ( ) . await . unwrap ( ) ;
880
+ assert_eq ! ( num_conn_released. load( Ordering :: SeqCst ) , 0 ) ;
881
+ }
882
+ }
883
+
884
+ // Invalid connections get released
885
+ {
886
+ valid. store ( false , Ordering :: SeqCst ) ;
887
+ let _connection_1 = pool. get ( ) . await . unwrap ( ) ;
888
+ assert_eq ! ( num_conn_released. load( Ordering :: SeqCst ) , 2 ) ;
889
+ let _connection_2 = pool. get ( ) . await . unwrap ( ) ;
890
+ assert_eq ! ( num_conn_released. load( Ordering :: SeqCst ) , 2 ) ;
891
+ valid. store ( true , Ordering :: SeqCst ) ;
892
+ }
893
+
894
+ // Broken connections get released
895
+ {
896
+ num_conn_released. store ( 0 , Ordering :: SeqCst ) ;
897
+ broken. store ( true , Ordering :: SeqCst ) ;
898
+ {
899
+ let _connection_1 = pool. get ( ) . await . unwrap ( ) ;
900
+ let _connection_2 = pool. get ( ) . await . unwrap ( ) ;
901
+ assert_eq ! ( num_conn_released. load( Ordering :: SeqCst ) , 0 ) ;
902
+ }
903
+ sleep ( Duration :: from_millis ( 100 ) ) . await ;
904
+ assert_eq ! ( num_conn_released. load( Ordering :: SeqCst ) , 2 ) ;
905
+ }
906
+ }
0 commit comments