Skip to content

Commit 87f8bdb

Browse files
authored
Merge pull request #81 from jnak/concurrency-bug
Make Promise and Dataloader thread-safe
2 parents 0d5366e + c519261 commit 87f8bdb

File tree

3 files changed

+124
-8
lines changed

3 files changed

+124
-8
lines changed

promise/async_.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# Based on https://github.com/petkaantonov/bluebird/blob/master/src/promise.js
22
from collections import deque
3+
from threading import local
34

45
if False:
56
from .promise import Promise
67
from typing import Any, Callable, Optional, Union # flake8: noqa
78

89

9-
class Async(object):
10+
class Async(local):
1011
def __init__(self, trampoline_enabled=True):
1112
self.is_tick_used = False
1213
self.late_queue = deque() # type: ignore

promise/dataloader.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
except ImportError:
55
from collections import Iterable
66
from functools import partial
7+
from threading import local
78

89
from .promise import Promise, async_instance, get_default_scheduler
910

@@ -33,7 +34,7 @@ def get_chunks(iterable_obj, chunk_size=1):
3334
Loader = namedtuple("Loader", "key,resolve,reject")
3435

3536

36-
class DataLoader(object):
37+
class DataLoader(local):
3738

3839
batch = True
3940
max_batch_size = None # type: int
@@ -212,22 +213,21 @@ def prime(self, key, value):
212213
# ensuring that it always occurs after "PromiseJobs" ends.
213214

214215
# Private: cached resolved Promise instance
215-
resolved_promise = None # type: Optional[Promise[None]]
216-
216+
cache = local()
217217

218218
def enqueue_post_promise_job(fn, scheduler):
219219
# type: (Callable, Any) -> None
220-
global resolved_promise
221-
if not resolved_promise:
222-
resolved_promise = Promise.resolve(None)
220+
global cache
221+
if not hasattr(cache, 'resolved_promise'):
222+
cache.resolved_promise = Promise.resolve(None)
223223
if not scheduler:
224224
scheduler = get_default_scheduler()
225225

226226
def on_promise_resolve(v):
227227
# type: (Any) -> None
228228
async_instance.invoke(fn, scheduler)
229229

230-
resolved_promise.then(on_promise_resolve)
230+
cache.resolved_promise.then(on_promise_resolve)
231231

232232

233233
def dispatch_queue(loader):

tests/test_thread_safety.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from promise import Promise
2+
from promise.dataloader import DataLoader
3+
import threading
4+
5+
6+
7+
def test_promise_thread_safety():
8+
"""
9+
Promise tasks should never be executed in a different thread from the one they are scheduled from,
10+
unless the ThreadPoolExecutor is used.
11+
12+
Here we assert that the pending promise tasks on thread 1 are not executed on thread 2 as thread 2
13+
resolves its own promise tasks.
14+
"""
15+
event_1 = threading.Event()
16+
event_2 = threading.Event()
17+
18+
assert_object = {'is_same_thread': True}
19+
20+
def task_1():
21+
thread_name = threading.current_thread().getName()
22+
23+
def then_1(value):
24+
# Enqueue tasks to run later.
25+
# This relies on the fact that `then` does not execute the function synchronously when called from
26+
# within another `then` callback function.
27+
promise = Promise.resolve(None).then(then_2)
28+
assert promise.is_pending
29+
event_1.set() # Unblock main thread
30+
event_2.wait() # Wait for thread 2
31+
32+
def then_2(value):
33+
assert_object['is_same_thread'] = (thread_name == threading.current_thread().getName())
34+
35+
promise = Promise.resolve(None).then(then_1)
36+
37+
def task_2():
38+
promise = Promise.resolve(None).then(lambda v: None)
39+
promise.get() # Drain task queue
40+
event_2.set() # Unblock thread 1
41+
42+
thread_1 = threading.Thread(target=task_1)
43+
thread_1.start()
44+
45+
event_1.wait() # Wait for Thread 1 to enqueue promise tasks
46+
47+
thread_2 = threading.Thread(target=task_2)
48+
thread_2.start()
49+
50+
for thread in (thread_1, thread_2):
51+
thread.join()
52+
53+
assert assert_object['is_same_thread']
54+
55+
56+
def test_dataloader_thread_safety():
57+
"""
58+
Dataloader should only batch `load` calls that happened on the same thread.
59+
60+
Here we assert that `load` calls on thread 2 are not batched on thread 1 as
61+
thread 1 batches its own `load` calls.
62+
"""
63+
def load_many(keys):
64+
thead_name = threading.current_thread().getName()
65+
return Promise.resolve([thead_name for key in keys])
66+
67+
thread_name_loader = DataLoader(load_many)
68+
69+
event_1 = threading.Event()
70+
event_2 = threading.Event()
71+
event_3 = threading.Event()
72+
73+
assert_object = {
74+
'is_same_thread_1': True,
75+
'is_same_thread_2': True,
76+
}
77+
78+
def task_1():
79+
@Promise.safe
80+
def do():
81+
promise = thread_name_loader.load(1)
82+
event_1.set()
83+
event_2.wait() # Wait for thread 2 to call `load`
84+
assert_object['is_same_thread_1'] = (
85+
promise.get() == threading.current_thread().getName()
86+
)
87+
event_3.set() # Unblock thread 2
88+
89+
do().get()
90+
91+
def task_2():
92+
@Promise.safe
93+
def do():
94+
promise = thread_name_loader.load(2)
95+
event_2.set()
96+
event_3.wait() # Wait for thread 1 to run `dispatch_queue_batch`
97+
assert_object['is_same_thread_2'] = (
98+
promise.get() == threading.current_thread().getName()
99+
)
100+
101+
do().get()
102+
103+
thread_1 = threading.Thread(target=task_1)
104+
thread_1.start()
105+
106+
event_1.wait() # Wait for thread 1 to call `load`
107+
108+
thread_2 = threading.Thread(target=task_2)
109+
thread_2.start()
110+
111+
for thread in (thread_1, thread_2):
112+
thread.join()
113+
114+
assert assert_object['is_same_thread_1']
115+
assert assert_object['is_same_thread_2']

0 commit comments

Comments
 (0)