Skip to content

Commit 4702fce

Browse files
committed
Add Dagger.spawn, allow on any node
1 parent 9ce4de1 commit 4702fce

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

src/thunk.jl

+14-5
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,18 @@ function Base.show(io::IO, t::EagerThunk)
123123
print(io, "EagerThunk ($(isready(t) ? "finished" : "running"))")
124124
end
125125

126+
function spawn(f, args...; kwargs...)
127+
if myid() == 1
128+
Dagger.Sch.init_eager()
129+
future = ThunkFuture()
130+
uid = next_id()
131+
put!(Dagger.Sch.EAGER_THUNK_CHAN, (future, uid, f, (args...,), (kwargs...,)))
132+
EagerThunk(future, uid)
133+
else
134+
remotecall_fetch(spawn, 1, f, args...; kwargs...)
135+
end
136+
end
137+
126138
"""
127139
@par [opts] f(args...) -> Thunk
128140
@@ -172,11 +184,8 @@ function _par(ex::Expr; lazy=true, recur=true, opts=())
172184
return :(Dagger.delayed($(esc(f)); $(opts...))($(_par.(args; lazy=lazy, recur=false)...)))
173185
else
174186
return quote
175-
Dagger.Sch.init_eager()
176-
future = $ThunkFuture()
177-
uid = $next_id()
178-
put!(Dagger.Sch.EAGER_THUNK_CHAN, (future, uid, $(esc(f)), ($(_par.(args; lazy=lazy, recur=false)...),), ($(opts...),)))
179-
EagerThunk(future, uid)
187+
args = ($(_par.(args; lazy=lazy, recur=false)...),)
188+
$spawn($(esc(f)), args...; $(opts...))
180189
end
181190
end
182191
else

test/thunk.jl

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Dagger: @par, @spawn
1+
import Dagger: @par, @spawn, spawn
22

33
@everywhere checkwid() = myid()==1
44

@@ -32,7 +32,8 @@ end
3232
a = @spawn x + x
3333
@test a isa Dagger.EagerThunk
3434
b = @spawn sum([x,1,2])
35-
c = @spawn a * b
35+
c = spawn(*, a, b)
36+
@test c isa Dagger.EagerThunk
3637
@test fetch(a) == 4
3738
@test fetch(b) == 5
3839
@test fetch(c) == 20
@@ -131,4 +132,11 @@ end
131132
@test_throws_unwrap Dagger.ThunkFailedException fetch(d)
132133
end
133134
end
135+
@testset "remote spawn" begin
136+
a = fetch(Distributed.@spawnat 2 Dagger.spawn(+, 1, 2))
137+
@test Dagger.Sch.EAGER_INIT[]
138+
@test fetch(Distributed.@spawnat 2 !(Dagger.Sch.EAGER_INIT[]))
139+
@test a isa Dagger.EagerThunk
140+
@test fetch(a) == 3
141+
end
134142
end

0 commit comments

Comments
 (0)