Skip to content

Commit ec6b048

Browse files
committed
async_hooks: fix async/await context loss in AsyncLocalStorage
1 parent 9545013 commit ec6b048

File tree

2 files changed

+85
-2
lines changed

2 files changed

+85
-2
lines changed

lib/async_hooks.js

+34-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
const {
44
NumberIsSafeInteger,
5+
PromiseResolve,
56
ReflectApply,
67
Symbol,
78
} = primordials;
@@ -211,13 +212,44 @@ class AsyncResource {
211212
}
212213

213214
const storageList = [];
215+
const seenLayer = [];
216+
let depth = 0;
217+
218+
function patchPromiseBarrier(currentResource) {
219+
PromiseResolve({
220+
then(resolve) {
221+
const resource = executionAsyncResource();
222+
propagateToStorageLists(resource, currentResource);
223+
resolve();
224+
}
225+
});
226+
}
227+
228+
function propagateToStorageLists(resource, currentResource) {
229+
for (let i = 0; i < storageList.length; ++i) {
230+
storageList[i]._propagate(resource, currentResource);
231+
}
232+
}
233+
214234
const storageHook = createHook({
215235
init(asyncId, type, triggerAsyncId, resource) {
216236
const currentResource = executionAsyncResource();
217237
// Value of currentResource is always a non null object
218-
for (let i = 0; i < storageList.length; ++i) {
219-
storageList[i]._propagate(resource, currentResource);
238+
propagateToStorageLists(resource, currentResource);
239+
240+
if (type === 'PROMISE' && !seenLayer[depth]) {
241+
seenLayer[depth] = true;
242+
patchPromiseBarrier(currentResource);
220243
}
244+
},
245+
246+
before(asyncId) {
247+
depth++;
248+
seenLayer[depth] = false;
249+
},
250+
251+
after(asyncId) {
252+
depth--;
221253
}
222254
});
223255

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
'use strict';
2+
const common = require('../common');
3+
const assert = require('assert');
4+
const { AsyncLocalStorage } = require('async_hooks');
5+
6+
const store = new AsyncLocalStorage();
7+
let checked = 0;
8+
9+
function thenable(expected, count) {
10+
return {
11+
then: common.mustCall((cb) => {
12+
assert.strictEqual(expected, store.getStore());
13+
checked++;
14+
cb();
15+
}, count)
16+
};
17+
}
18+
19+
function main(n) {
20+
const firstData = Symbol('first-data');
21+
const secondData = Symbol('second-data');
22+
23+
const first = thenable(firstData, 1);
24+
const second = thenable(secondData, 1);
25+
const third = thenable(firstData, 2);
26+
27+
return store.run(firstData, common.mustCall(async () => {
28+
assert.strictEqual(firstData, store.getStore());
29+
await first;
30+
31+
await store.run(secondData, common.mustCall(async () => {
32+
assert.strictEqual(secondData, store.getStore());
33+
await second;
34+
assert.strictEqual(secondData, store.getStore());
35+
}));
36+
37+
await Promise.all([ third, third ]);
38+
assert.strictEqual(firstData, store.getStore());
39+
}));
40+
}
41+
42+
const outerData = Symbol('outer-data');
43+
44+
Promise.all([
45+
store.run(outerData, () => Promise.resolve(thenable(outerData))),
46+
Promise.resolve(3).then(common.mustCall(main)),
47+
main(1),
48+
main(2)
49+
]).then(common.mustCall(() => {
50+
assert.strictEqual(checked, 13);
51+
}));

0 commit comments

Comments
 (0)