Skip to content

Commit a6c7b3f

Browse files
authored
Merge pull request #2 from postgresml/levkk-fix-waiters-2
Fix deadlock
2 parents ea5c162 + 0dc8eda commit a6c7b3f

File tree

4 files changed

+142
-92
lines changed

4 files changed

+142
-92
lines changed

bb8/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ async-trait = "0.1"
1414
futures-channel = "0.3.2"
1515
futures-util = { version = "0.3.2", default-features = false, features = ["channel"] }
1616
parking_lot = { version = "0.12", optional = true }
17-
tokio = { version = "1.0", features = ["rt", "time"] }
17+
tokio = { version = "1.0", features = ["rt", "time", "sync"] }
1818

1919
[dev-dependencies]
2020
tokio = { version = "1.0", features = ["macros"] }

bb8/src/inner.rs

+67-30
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use std::future::Future;
44
use std::sync::{Arc, Weak};
55
use std::time::{Duration, Instant};
66

7-
use futures_channel::oneshot;
87
use futures_util::stream::{FuturesUnordered, StreamExt};
98
use futures_util::TryFutureExt;
109
use tokio::spawn;
@@ -103,46 +102,84 @@ where
103102
&'a self,
104103
make_pooled_conn: F,
105104
) -> Result<PooledConnection<'b, M>, RunError<M::Error>>
105+
where
106+
F: Fn(&'a Self, Conn<M::Connection>) -> PooledConnection<'b, M>,
107+
{
108+
match timeout(
109+
self.inner.statics.connection_timeout,
110+
self.make_pooled_internal(make_pooled_conn),
111+
)
112+
.await
113+
{
114+
Ok(result) => result,
115+
_ => Err(RunError::TimedOut),
116+
}
117+
}
118+
119+
async fn make_pooled_internal<'a, 'b, F>(
120+
&'a self,
121+
make_pooled_conn: F,
122+
) -> Result<PooledConnection<'b, M>, RunError<M::Error>>
106123
where
107124
F: Fn(&'a Self, Conn<M::Connection>) -> PooledConnection<'b, M>,
108125
{
109126
loop {
110-
let mut conn = {
111-
let mut locked = self.inner.internals.lock();
112-
match locked.pop(&self.inner.statics) {
113-
Some((conn, approvals)) => {
114-
self.spawn_replenishing_approvals(approvals);
115-
make_pooled_conn(self, conn)
116-
}
117-
None => break,
127+
loop {
128+
// Get in the same queue as everyone else for a connection.
129+
let waiter = {
130+
let locked = self.inner.internals.lock();
131+
locked.request_connection()
132+
};
133+
134+
// A connection is availble, the waiter has a chance to get it.
135+
if let Some(waiter) = waiter {
136+
waiter.notified().await;
118137
}
119-
};
120138

121-
if !self.inner.statics.test_on_check_out {
122-
return Ok(conn);
123-
}
139+
// Try to get the connection if it's still availble.
140+
let mut conn = {
141+
let mut locked = self.inner.internals.lock();
124142

125-
match self.inner.manager.is_valid(&mut conn).await {
126-
Ok(()) => return Ok(conn),
127-
Err(e) => {
128-
self.inner.forward_error(e);
129-
conn.drop_invalid();
130-
continue;
143+
match locked.pop(&self.inner.statics) {
144+
Some((conn, approvals)) => {
145+
self.spawn_replenishing_approvals(approvals);
146+
make_pooled_conn(self, conn)
147+
}
148+
149+
// All open connections are gone, go make a new one and wait.
150+
None => break,
151+
}
152+
};
153+
154+
if !self.inner.statics.test_on_check_out {
155+
return Ok(conn);
156+
}
157+
158+
match self.inner.manager.is_valid(&mut conn).await {
159+
Ok(()) => return Ok(conn),
160+
Err(e) => {
161+
self.inner.statics.error_sink.sink(e);
162+
conn.drop_invalid();
163+
continue;
164+
}
131165
}
132166
}
133-
}
134167

135-
let (tx, rx) = oneshot::channel();
136-
{
137-
let mut locked = self.inner.internals.lock();
138-
let approvals = locked.push_waiter(tx, &self.inner.statics);
139-
self.spawn_replenishing_approvals(approvals);
140-
};
168+
// No connection is available, wait for one to be created for us.
169+
let waiter = {
170+
let mut locked = self.inner.internals.lock();
171+
let (waiter, approvals) = locked.push_waiter(&self.inner.statics);
172+
self.spawn_replenishing_approvals(approvals);
173+
waiter
174+
};
141175

142-
match timeout(self.inner.statics.connection_timeout, rx).await {
143-
Ok(Ok(Ok(mut guard))) => Ok(make_pooled_conn(self, guard.extract())),
144-
Ok(Ok(Err(e))) => Err(RunError::User(e)),
145-
_ => Err(RunError::TimedOut),
176+
waiter.notified().await;
177+
178+
// Did we get it? No? Let's keep waiting.
179+
match self.inner.internals.lock().pop(&self.inner.statics) {
180+
Some(conn) => return Ok(make_pooled_conn(self, conn.0)),
181+
None => continue,
182+
};
146183
}
147184
}
148185

bb8/src/internals.rs

+24-60
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::sync::Arc;
33
use std::time::Instant;
44

55
use crate::{api::QueueStrategy, lock::Mutex};
6-
use futures_channel::oneshot;
6+
use tokio::sync::Notify;
77

88
use crate::api::{Builder, ManageConnection};
99
use std::collections::VecDeque;
@@ -31,15 +31,7 @@ where
3131
}
3232
}
3333

34-
pub(crate) fn forward_error(&self, mut err: M::Error) {
35-
let mut locked = self.internals.lock();
36-
while let Some(waiter) = locked.waiters.pop_front() {
37-
match waiter.send(Err(err)) {
38-
Ok(_) => return,
39-
Err(Err(e)) => err = e,
40-
Err(Ok(_)) => unreachable!(),
41-
}
42-
}
34+
pub(crate) fn forward_error(&self, err: M::Error) {
4335
self.statics.error_sink.sink(err);
4436
}
4537
}
@@ -50,7 +42,7 @@ pub(crate) struct PoolInternals<M>
5042
where
5143
M: ManageConnection,
5244
{
53-
waiters: VecDeque<oneshot::Sender<Result<InternalsGuard<M>, M::Error>>>,
45+
notify: Arc<Notify>,
5446
conns: VecDeque<IdleConn<M::Connection>>,
5547
num_conns: u32,
5648
pending_conns: u32,
@@ -82,24 +74,14 @@ where
8274

8375
let queue_strategy = pool.statics.queue_strategy;
8476

85-
let mut guard = InternalsGuard::new(conn, pool);
86-
while let Some(waiter) = self.waiters.pop_front() {
87-
// This connection is no longer idle, send it back out
88-
match waiter.send(Ok(guard)) {
89-
Ok(()) => return,
90-
Err(Ok(g)) => {
91-
guard = g;
92-
}
93-
Err(Err(_)) => unreachable!(),
94-
}
95-
}
96-
9777
// Queue it in the idle queue
98-
let conn = IdleConn::from(guard.conn.take().unwrap());
78+
let conn = IdleConn::from(conn);
9979
match queue_strategy {
10080
QueueStrategy::Fifo => self.conns.push_back(conn),
10181
QueueStrategy::Lifo => self.conns.push_front(conn),
102-
}
82+
};
83+
84+
self.notify.notify_one()
10385
}
10486

10587
pub(crate) fn connect_failed(&mut self, _: Approval) {
@@ -123,13 +105,22 @@ where
123105
self.approvals(config, wanted)
124106
}
125107

126-
pub(crate) fn push_waiter(
127-
&mut self,
128-
waiter: oneshot::Sender<Result<InternalsGuard<M>, M::Error>>,
129-
config: &Builder<M>,
130-
) -> ApprovalIter {
131-
self.waiters.push_back(waiter);
132-
self.approvals(config, 1)
108+
pub(crate) fn push_waiter(&mut self, config: &Builder<M>) -> (Arc<Notify>, ApprovalIter) {
109+
let notify = self.notify.clone();
110+
let approvals = self.approvals(config, 1);
111+
112+
(notify, approvals)
113+
}
114+
115+
pub(crate) fn request_connection(&self) -> Option<Arc<Notify>> {
116+
let notify = self.notify.clone();
117+
118+
if !self.conns.is_empty() {
119+
self.notify.notify_one();
120+
Some(notify)
121+
} else {
122+
None
123+
}
133124
}
134125

135126
fn approvals(&mut self, config: &Builder<M>, num: u32) -> ApprovalIter {
@@ -177,41 +168,14 @@ where
177168
{
178169
fn default() -> Self {
179170
Self {
180-
waiters: VecDeque::new(),
171+
notify: Arc::new(Notify::new()),
181172
conns: VecDeque::new(),
182173
num_conns: 0,
183174
pending_conns: 0,
184175
}
185176
}
186177
}
187178

188-
pub(crate) struct InternalsGuard<M: ManageConnection> {
189-
conn: Option<Conn<M::Connection>>,
190-
pool: Arc<SharedPool<M>>,
191-
}
192-
193-
impl<M: ManageConnection> InternalsGuard<M> {
194-
fn new(conn: Conn<M::Connection>, pool: Arc<SharedPool<M>>) -> Self {
195-
Self {
196-
conn: Some(conn),
197-
pool,
198-
}
199-
}
200-
201-
pub(crate) fn extract(&mut self) -> Conn<M::Connection> {
202-
self.conn.take().unwrap() // safe: can only be `None` after `Drop`
203-
}
204-
}
205-
206-
impl<M: ManageConnection> Drop for InternalsGuard<M> {
207-
fn drop(&mut self) {
208-
if let Some(conn) = self.conn.take() {
209-
let mut locked = self.pool.internals.lock();
210-
locked.put(conn, None, self.pool.clone());
211-
}
212-
}
213-
}
214-
215179
#[must_use]
216180
pub(crate) struct ApprovalIter {
217181
num: usize,

bb8/tests/test.rs

+50-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ async fn test_lazy_initialization_failure_no_retry() {
282282
.build_unchecked(manager);
283283

284284
let res = pool.get().await;
285-
assert_eq!(res.unwrap_err(), RunError::User(Error));
285+
assert_eq!(res.unwrap_err(), RunError::TimedOut);
286286
}
287287

288288
#[tokio::test]
@@ -317,6 +317,55 @@ async fn test_get_timeout() {
317317
ready(r).await.unwrap();
318318
}
319319

320+
#[tokio::test]
321+
async fn test_lots_of_waiters() {
322+
let pool = Pool::builder()
323+
.max_size(3)
324+
.connection_timeout(Duration::from_millis(5_000))
325+
.build(OkManager::<FakeConnection>::new())
326+
.await
327+
.unwrap();
328+
329+
let mut waiters: Vec<oneshot::Receiver<()>> = Vec::new();
330+
331+
for _ in 0..25000 {
332+
let pool = pool.clone();
333+
let (tx, rx) = oneshot::channel();
334+
waiters.push(rx);
335+
tokio::spawn(async move {
336+
let _conn = pool.get().await.unwrap();
337+
tx.send(()).unwrap();
338+
});
339+
}
340+
341+
let results = futures_util::future::join_all(&mut waiters).await;
342+
343+
for result in results {
344+
assert!(result.is_ok());
345+
}
346+
}
347+
348+
#[tokio::test]
349+
async fn test_timeout_caller() {
350+
let pool = Pool::builder()
351+
.max_size(1)
352+
.connection_timeout(Duration::from_millis(5_000))
353+
.build(OkManager::<FakeConnection>::new())
354+
.await
355+
.unwrap();
356+
357+
let one = pool.get().await;
358+
assert!(one.is_ok());
359+
360+
let res = tokio::time::timeout(Duration::from_millis(100), pool.get()).await;
361+
assert!(res.is_err());
362+
363+
drop(one);
364+
365+
let two = pool.get().await;
366+
assert!(two.is_ok());
367+
}
368+
320369
#[tokio::test]
321370
async fn test_now_invalid() {
322371
static INVALID: AtomicBool = AtomicBool::new(false);

0 commit comments

Comments
 (0)