Skip to content

Commit 1497a24

Browse files
authored
Merge pull request #575 from JuliaParallel/jps/parser-fix-bcast
parser: Re-support broadcast
2 parents 9bdf2d7 + 656f952 commit 1497a24

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

Diff for: src/thunk.jl

+13-8
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ generated thunks.
307307
macro par(exs...)
308308
opts = exs[1:end-1]
309309
ex = exs[end]
310-
return esc(_par(ex; lazy=true, opts=opts))
310+
return esc(_par(__module__, ex; lazy=true, opts=opts))
311311
end
312312

313313
"""
@@ -349,7 +349,7 @@ also passes along any options in an `Options` struct. For example,
349349
macro spawn(exs...)
350350
opts = exs[1:end-1]
351351
ex = exs[end]
352-
return esc(_par(ex; lazy=false, opts=opts))
352+
return esc(_par(__module__, ex; lazy=false, opts=opts))
353353
end
354354

355355
struct ExpandedBroadcast{F} end
@@ -365,8 +365,9 @@ end
365365

366366
to_namedtuple(;kwargs...) = (;kwargs...)
367367

368-
function _par(ex::Expr; lazy=true, recur=true, opts=())
368+
function _par(mod, ex::Expr; lazy=true, recur=true, opts=())
369369
f = nothing
370+
bf = nothing
370371
body = nothing
371372
arg1 = nothing
372373
arg2 = nothing
@@ -375,7 +376,11 @@ function _par(ex::Expr; lazy=true, recur=true, opts=())
375376
@capture(ex, allargs__->body_) ||
376377
@capture(ex, arg1_[allargs__]) ||
377378
@capture(ex, arg1_.arg2_) ||
378-
@capture(ex, (;allargs__))
379+
@capture(ex, (;allargs__)) ||
380+
@capture(ex, bf_.(allargs__))
381+
if bf !== nothing
382+
f = ExpandedBroadcast{mod.eval(bf)}()
383+
end
379384
f = replace_broadcast(f)
380385
if arg1 !== nothing
381386
if arg2 !== nothing
@@ -429,15 +434,15 @@ function _par(ex::Expr; lazy=true, recur=true, opts=())
429434
end
430435
elseif lazy
431436
# Recurse into the expression
432-
return Expr(ex.head, _par_inner.(ex.args, lazy=lazy, recur=recur, opts=opts)...)
437+
return Expr(ex.head, _par_inner.(Ref(mod), ex.args, lazy=lazy, recur=recur, opts=opts)...)
433438
else
434439
throw(ArgumentError("Invalid Dagger task expression: $ex"))
435440
end
436441
end
437-
_par(ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: $ex"))
442+
_par(mod, ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: $ex"))
438443

439-
_par_inner(ex; kwargs...) = ex
440-
_par_inner(ex::Expr; kwargs...) = _par(ex; kwargs...)
444+
_par_inner(mod, ex; kwargs...) = ex
445+
_par_inner(mod, ex::Expr; kwargs...) = _par(mod, ex; kwargs...)
441446

442447
"""
443448
Dagger.spawn(f, args...; kwargs...) -> DTask

Diff for: test/thunk.jl

+7
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ end
160160
@test t isa Dagger.DTask
161161
@test fetch(t) == fetch(nt2).b
162162
end
163+
@testset "broadcast" begin
164+
x = randn(100)
165+
166+
t = @spawn abs.(x)
167+
@test t isa Dagger.DTask
168+
@test fetch(t) == abs.(x)
169+
end
163170
@testset "invalid expression" begin
164171
@test_throws LoadError eval(:(@spawn 1))
165172
@test_throws LoadError eval(:(@spawn begin 1 end))

0 commit comments

Comments
 (0)