Skip to content

Commit 1a47091

Browse files
committed
Make Dagger.spawn safely nestable
1 parent 26e70d3 commit 1a47091

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

Diff for: src/sch/eager.jl

+33-7
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,44 @@ function init_eager()
2929
end
3030
end
3131

32-
"Sets the scheduler's cached pressure indicator for the specified worker."
33-
set_pressure!(h::SchedulerHandle, pid::Int, proctype::Type, pressure::Int) =
34-
exec!(_set_pressure!, h, pid, proctype, pressure)
35-
function _set_pressure!(ctx, state, task, tid, (pid, proctype, pressure))
36-
state.worker_pressure[pid][proctype] = pressure
37-
ACTIVE_TASKS[state.uid][proctype] = pressure # HACK-ish
32+
"Adjusts the scheduler's cached pressure indicator for the specified worker by
33+
the specified amount."
34+
function adjust_pressure!(h::SchedulerHandle, proctype::Type, pressure)
35+
uid = Dagger.get_tls().sch_uid
36+
lock(ACTIVE_TASKS_LOCK) do
37+
ACTIVE_TASKS[uid][proctype][] += pressure
38+
end
39+
exec!(_adjust_pressure!, h, myid(), proctype, pressure)
40+
end
41+
function _adjust_pressure!(ctx, state, task, tid, (pid, proctype, pressure))
42+
state.worker_pressure[pid][proctype] += pressure
3843
nothing
3944
end
4045

46+
"Allows a thunk to safely wait on another thunk, by temporarily reducing its
47+
effective pressure to 0."
48+
function thunk_yield(f)
49+
if Dagger.in_thunk()
50+
h = sch_handle()
51+
tls = Dagger.get_tls()
52+
proctype = typeof(tls.processor)
53+
util = tls.utilization
54+
adjust_pressure!(h, proctype, -util)
55+
try
56+
f()
57+
finally
58+
adjust_pressure!(h, proctype, util)
59+
end
60+
else
61+
f()
62+
end
63+
end
64+
4165
function eager_thunk()
4266
h = sch_handle()
43-
set_pressure!(h, 1, Dagger.ThreadProc, 0) # HACK: Don't apply pressure from this thunk
67+
util = Dagger.get_tls().utilization
68+
# Don't apply pressure from this thunk
69+
adjust_pressure!(h, Dagger.ThreadProc, -util)
4470
while isopen(EAGER_THUNK_CHAN)
4571
try
4672
future, uid, f, args, opts = take!(EAGER_THUNK_CHAN)

Diff for: src/thunk.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,13 @@ end
9090
ThunkFuture(x::Integer) = ThunkFuture(Future(x))
9191
ThunkFuture() = ThunkFuture(Future())
9292
Base.isready(t::ThunkFuture) = isready(t.future)
93-
Base.wait(t::ThunkFuture) = wait(t.future)
93+
Base.wait(t::ThunkFuture) = Dagger.Sch.thunk_yield() do
94+
wait(t.future)
95+
end
9496
function Base.fetch(t::ThunkFuture; proc=OSProc())
95-
error, value = move(proc, fetch(t.future))
97+
error, value = Dagger.Sch.thunk_yield() do
98+
move(proc, fetch(t.future))
99+
end
96100
if error
97101
throw(value)
98102
end

Diff for: test/thunk.jl

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
import Dagger: @par, @spawn, spawn
22

3-
@everywhere checkwid() = myid()==1
3+
@everywhere begin
4+
checkwid() = myid()==1
5+
function dynamic_fib(n)
6+
n <= 1 && return n
7+
t = Dagger.spawn(dynamic_fib, n-1)
8+
return (fetch(t)::Int) + dynamic_fib(n-2)
9+
end
10+
end
411

512
@testset "@par" begin
613
@testset "per-call" begin
@@ -138,5 +145,8 @@ end
138145
@test fetch(Distributed.@spawnat 2 !(Dagger.Sch.EAGER_INIT[]))
139146
@test a isa Dagger.EagerThunk
140147
@test fetch(a) == 3
148+
149+
# Mild stress-test
150+
@test fetch(dynamic_fib(10)) == 55
141151
end
142152
end

0 commit comments

Comments
 (0)