Skip to content

Commit 9c01b65

Browse files
authored
Merge pull request #217 from JuliaParallel/jps/better-spawn
Add Dagger.spawn, allow on any node
2 parents 9ce4de1 + 838eb8d commit 9c01b65

File tree

8 files changed

+125
-40
lines changed

8 files changed

+125
-40
lines changed

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Dagger"
22
uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
3-
version = "0.11.2"
3+
version = "0.11.3"
44

55
[deps]
66
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"

Diff for: docs/src/index.md

+19-8
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,14 @@ compute(top_node)
8282

8383
## Eager Execution
8484

85-
Similar to `@par`, Dagger has an `@spawn` macro which works similarly to
86-
`@async` and `Threads.@spawn`: when called, it wraps the function call
87-
specified by the user in an `EagerThunk` object, and immediately places it onto
88-
a running scheduler, to be executed once its dependencies are fulfilled. This
89-
contrasts with `@par` in that `@par` does not begin executing its thunks until
90-
`collect` or `compute` are called on a given thunk or one of its downstream
91-
dependencies. Additionally, one fetches the result of an `@spawn` call with
92-
`fetch`. As a concrete example:
85+
Similar to `@par`, Dagger has an `@spawn` macro (and matching `Dagger.spawn`)
86+
which works similarly to `@async` and `Threads.@spawn`: when called, it wraps
87+
the function call specified by the user in an `EagerThunk` object, and
88+
immediately places it onto a running scheduler, to be executed once its
89+
dependencies are fulfilled. This contrasts with `@par` in that `@par` does not
90+
begin executing its thunks until `collect` or `compute` are called on a given
91+
thunk or one of its downstream dependencies. Additionally, one fetches the
92+
result of any `@spawn` call with `fetch`. As a concrete example:
9393

9494
```julia
9595
x = rand(400,400)
@@ -109,6 +109,17 @@ wait(x)
109109
@info "Done!"
110110
```
111111

112+
One can also safely call `@spawn` from another worker (not id 1), and it will
113+
be sent to worker 1 to schedule:
114+
115+
```
116+
x = fetch(Distributed.@spawnat 2 Dagger.@spawn 1+2) # actually scheduled on worker 1
117+
x::EagerThunk
118+
@assert fetch(x) == 3
119+
```
120+
121+
This is useful for nested execution, where an `@spawn`'d thunk calls `@spawn`.
122+
112123
If a thunk errors while running under the eager scheduler, it will be marked as
113124
having failed, all dependent (downstream) thunks will be marked as failed, and
114125
any future thunks that use a failed thunk as input will fail. Failure can be

Diff for: src/processor.jl

+11-4
Original file line numberDiff line numberDiff line change
@@ -312,16 +312,23 @@ rmprocs!(ctx::Context, xs::AbstractVector{<:OSProc}) = lock(ctx) do
312312
end
313313

314314
"Gets the current processor executing the current thunk."
315-
thunk_processor() = task_local_storage(:processor)::Processor
315+
thunk_processor() = task_local_storage(:_dagger_processor)::Processor
316+
317+
"Determines if we're currently in a thunk context."
318+
in_thunk() = haskey(task_local_storage(), :_dagger_sch_uid)
316319

317320
"Gets all Dagger TLS variables as a NamedTuple."
318321
get_tls() = (
322+
sch_uid=task_local_storage(:_dagger_sch_uid),
323+
sch_handle=task_local_storage(:_dagger_sch_handle),
319324
processor=thunk_processor(),
320-
sch_handle=task_local_storage(:sch_handle)
325+
utilization=task_local_storage(:_dagger_utilization),
321326
)
322327

323328
"Sets all Dagger TLS variables from a NamedTuple."
324329
function set_tls!(tls)
325-
task_local_storage(:processor, tls.processor)
326-
task_local_storage(:sch_handle, tls.sch_handle)
330+
task_local_storage(:_dagger_sch_uid, tls.sch_uid)
331+
task_local_storage(:_dagger_sch_handle, tls.sch_handle)
332+
task_local_storage(:_dagger_processor, tls.processor)
333+
task_local_storage(:_dagger_utilization, tls.utilization)
327334
end

Diff for: src/sch/Sch.jl

+9-3
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,9 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions())
290290
end
291291
end
292292
end
293-
state.worker_pressure[pid][typeof(proc)] = metadata.pressure
293+
if metadata !== nothing
294+
state.worker_pressure[pid][typeof(proc)] = metadata.pressure
295+
end
294296
node = state.thunk_dict[thunk_id]
295297
state.cache[node] = res
296298
if node.options !== nothing && node.options.checkpoint !== nothing
@@ -661,8 +663,12 @@ end
661663
res = nothing
662664
result_meta = try
663665
# Set TLS variables
664-
task_local_storage(:processor, to_proc)
665-
task_local_storage(:sch_handle, sch_handle)
666+
Dagger.set_tls!((
667+
sch_uid=uid,
668+
sch_handle=sch_handle,
669+
processor=to_proc,
670+
utilization=extra_util,
671+
))
666672

667673
# Execute
668674
res = execute!(to_proc, f, fetched...)

Diff for: src/sch/dynamic.jl

+10-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ struct SchedulerHandle
1414
end
1515

1616
"Gets the scheduler handle for the currently-executing thunk."
17-
sch_handle() = task_local_storage(:sch_handle)::SchedulerHandle
17+
sch_handle() = task_local_storage(:_dagger_sch_handle)::SchedulerHandle
1818

1919
"Thrown when the scheduler halts before finishing processing the DAG."
2020
struct SchedulerHaltedException <: Exception end
@@ -33,7 +33,7 @@ end
3333

3434
"Processes dynamic messages from worker-executing thunks."
3535
function dynamic_listener!(ctx, state)
36-
task = current_task()
36+
task = current_task() # The scheduler's main task
3737
for tid in keys(state.worker_chans)
3838
inp_chan, out_chan = state.worker_chans[tid]
3939
@async begin
@@ -71,10 +71,14 @@ end
7171

7272
## Worker-side methods for dynamic communication
7373

74+
const DYNAMIC_EXEC_LOCK = Threads.ReentrantLock()
75+
7476
"Executes an arbitrary function within the scheduler, returning the result."
7577
function exec!(f, h::SchedulerHandle, args...)
76-
put!(h.out_chan, (h.thunk_id.id, f, args))
77-
failed, res = take!(h.inp_chan)
78+
failed, res = lock(DYNAMIC_EXEC_LOCK) do
79+
put!(h.out_chan, (h.thunk_id.id, f, args))
80+
take!(h.inp_chan)
81+
end
7882
failed && throw(res)
7983
res
8084
end
@@ -138,13 +142,13 @@ end
138142

139143
"Adds a new Thunk to the DAG."
140144
add_thunk!(f, h::SchedulerHandle, args...; kwargs...) =
141-
ThunkID(exec!(_add_thunk!, h, f, args, kwargs)::Int)
145+
ThunkID(exec!(_add_thunk!, h, f, args, kwargs))
142146
function _add_thunk!(ctx, state, task, tid, (f, args, kwargs))
143147
_args = map(arg->arg isa ThunkID ? state.thunk_dict[arg.id] : arg, args)
144148
thunk = Thunk(f, _args...; kwargs...)
145149
state.thunk_dict[thunk.id] = thunk
146150
state.dependents[thunk] = Set{Thunk}()
147151
@assert reschedule_inputs!(state, thunk) || (thunk in state.errored)
148152
schedule!(ctx, state)
149-
return thunk.id
153+
return thunk.id::Int
150154
end

Diff for: src/sch/eager.jl

+34-8
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,44 @@ function init_eager()
2929
end
3030
end
3131

32-
"Sets the scheduler's cached pressure indicator for the specified worker."
33-
set_pressure!(h::SchedulerHandle, pid::Int, proctype::Type, pressure::Int) =
34-
exec!(_set_pressure!, h, pid, proctype, pressure)
35-
function _set_pressure!(ctx, state, task, tid, (pid, proctype, pressure))
36-
state.worker_pressure[pid][proctype] = pressure
37-
ACTIVE_TASKS[state.uid][proctype] = pressure # HACK-ish
38-
@show state.worker_pressure[pid]
32+
"Adjusts the scheduler's cached pressure indicator for the specified worker by
33+
the specified amount."
34+
function adjust_pressure!(h::SchedulerHandle, proctype::Type, pressure)
35+
uid = Dagger.get_tls().sch_uid
36+
lock(ACTIVE_TASKS_LOCK) do
37+
ACTIVE_TASKS[uid][proctype][] += pressure
38+
end
39+
exec!(_adjust_pressure!, h, myid(), proctype, pressure)
40+
end
41+
function _adjust_pressure!(ctx, state, task, tid, (pid, proctype, pressure))
42+
state.worker_pressure[pid][proctype] += pressure
43+
nothing
44+
end
45+
46+
"Allows a thunk to safely wait on another thunk, by temporarily reducing its
47+
effective pressure to 0."
48+
function thunk_yield(f)
49+
if Dagger.in_thunk()
50+
h = sch_handle()
51+
tls = Dagger.get_tls()
52+
proctype = typeof(tls.processor)
53+
util = tls.utilization
54+
adjust_pressure!(h, proctype, -util)
55+
try
56+
f()
57+
finally
58+
adjust_pressure!(h, proctype, util)
59+
end
60+
else
61+
f()
62+
end
3963
end
4064

4165
function eager_thunk()
4266
h = sch_handle()
43-
set_pressure!(h, 1, Dagger.ThreadProc, 0) # HACK: Don't apply pressure from this thunk
67+
util = Dagger.get_tls().utilization
68+
# Don't apply pressure from this thunk
69+
adjust_pressure!(h, Dagger.ThreadProc, -util)
4470
while isopen(EAGER_THUNK_CHAN)
4571
try
4672
future, uid, f, args, opts = take!(EAGER_THUNK_CHAN)

Diff for: src/thunk.jl

+20-7
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,13 @@ end
9090
ThunkFuture(x::Integer) = ThunkFuture(Future(x))
9191
ThunkFuture() = ThunkFuture(Future())
9292
Base.isready(t::ThunkFuture) = isready(t.future)
93-
Base.wait(t::ThunkFuture) = wait(t.future)
93+
Base.wait(t::ThunkFuture) = Dagger.Sch.thunk_yield() do
94+
wait(t.future)
95+
end
9496
function Base.fetch(t::ThunkFuture; proc=OSProc())
95-
error, value = move(proc, fetch(t.future))
97+
error, value = Dagger.Sch.thunk_yield() do
98+
move(proc, fetch(t.future))
99+
end
96100
if error
97101
throw(value)
98102
end
@@ -123,6 +127,18 @@ function Base.show(io::IO, t::EagerThunk)
123127
print(io, "EagerThunk ($(isready(t) ? "finished" : "running"))")
124128
end
125129

130+
function spawn(f, args...; kwargs...)
131+
if myid() == 1
132+
Dagger.Sch.init_eager()
133+
future = ThunkFuture()
134+
uid = next_id()
135+
put!(Dagger.Sch.EAGER_THUNK_CHAN, (future, uid, f, (args...,), (kwargs...,)))
136+
EagerThunk(future, uid)
137+
else
138+
remotecall_fetch(spawn, 1, f, args...; kwargs...)
139+
end
140+
end
141+
126142
"""
127143
@par [opts] f(args...) -> Thunk
128144
@@ -172,11 +188,8 @@ function _par(ex::Expr; lazy=true, recur=true, opts=())
172188
return :(Dagger.delayed($(esc(f)); $(opts...))($(_par.(args; lazy=lazy, recur=false)...)))
173189
else
174190
return quote
175-
Dagger.Sch.init_eager()
176-
future = $ThunkFuture()
177-
uid = $next_id()
178-
put!(Dagger.Sch.EAGER_THUNK_CHAN, (future, uid, $(esc(f)), ($(_par.(args; lazy=lazy, recur=false)...),), ($(opts...),)))
179-
EagerThunk(future, uid)
191+
args = ($(_par.(args; lazy=lazy, recur=false)...),)
192+
$spawn($(esc(f)), args...; $(opts...))
180193
end
181194
end
182195
else

Diff for: test/thunk.jl

+21-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1-
import Dagger: @par, @spawn
1+
import Dagger: @par, @spawn, spawn
22

3-
@everywhere checkwid() = myid()==1
3+
@everywhere begin
4+
checkwid() = myid()==1
5+
function dynamic_fib(n)
6+
n <= 1 && return n
7+
t = Dagger.spawn(dynamic_fib, n-1)
8+
return (fetch(t)::Int) + dynamic_fib(n-2)
9+
end
10+
end
411

512
@testset "@par" begin
613
@testset "per-call" begin
@@ -32,7 +39,8 @@ end
3239
a = @spawn x + x
3340
@test a isa Dagger.EagerThunk
3441
b = @spawn sum([x,1,2])
35-
c = @spawn a * b
42+
c = spawn(*, a, b)
43+
@test c isa Dagger.EagerThunk
3644
@test fetch(a) == 4
3745
@test fetch(b) == 5
3846
@test fetch(c) == 20
@@ -131,4 +139,14 @@ end
131139
@test_throws_unwrap Dagger.ThunkFailedException fetch(d)
132140
end
133141
end
142+
@testset "remote spawn" begin
143+
a = fetch(Distributed.@spawnat 2 Dagger.spawn(+, 1, 2))
144+
@test Dagger.Sch.EAGER_INIT[]
145+
@test fetch(Distributed.@spawnat 2 !(Dagger.Sch.EAGER_INIT[]))
146+
@test a isa Dagger.EagerThunk
147+
@test fetch(a) == 3
148+
149+
# Mild stress-test
150+
@test fetch(dynamic_fib(10)) == 55
151+
end
134152
end

0 commit comments

Comments
 (0)