Skip to content

Commit 5836c4d

Browse files
authored
Merge pull request #394 from JuliaParallel/jps/kwargs
Add keyword argument support
2 parents 4b83c4b + 1c8878d commit 5836c4d

20 files changed

+227
-126
lines changed

.buildkite/pipeline.yml

+11-1
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,22 @@ steps:
3535
julia_args: "--threads=1"
3636
- JuliaCI/julia-coverage#v1:
3737
codecov: true
38+
- label: Julia 1.9
39+
timeout_in_minutes: 60
40+
<<: *test
41+
plugins:
42+
- JuliaCI/julia#v1:
43+
version: "1.9"
44+
- JuliaCI/julia-test#v1:
45+
julia_args: "--threads=1"
46+
- JuliaCI/julia-coverage#v1:
47+
codecov: true
3848
- label: Julia nightly
3949
timeout_in_minutes: 60
4050
<<: *test
4151
plugins:
4252
- JuliaCI/julia#v1:
43-
version: "1.9-nightly"
53+
version: "1.10-nightly"
4454
- JuliaCI/julia-test#v1:
4555
julia_args: "--threads=1"
4656
- JuliaCI/julia-coverage#v1:

docs/src/checkpointing.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,17 @@ Let's see how we'd modify the above example to use checkpointing:
5454

5555
```julia
5656
using Serialization
57+
5758
X = compute(randn(Blocks(128,128), 1024, 1024))
58-
Y = [delayed(sum; options=Dagger.Sch.ThunkOptions(;
59-
checkpoint=(thunk,result)->begin
59+
Y = [delayed(sum; checkpoint=(thunk,result)->begin
6060
open("checkpoint-$idx.bin", "w") do io
6161
serialize(io, collect(result))
6262
end
6363
end, restore=(thunk)->begin
6464
open("checkpoint-$idx.bin", "r") do io
6565
Dagger.tochunk(deserialize(io))
6666
end
67-
end))(chunk) for (idx,chunk) in enumerate(X.chunks)]
67+
end)(chunk) for (idx,chunk) in enumerate(X.chunks)]
6868
inner(x...) = sqrt(sum(x))
6969
Z = delayed(inner)(Y...)
7070
z = collect(Z)

docs/src/index.md

+23-22
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,34 @@
22

33
## Usage
44

5-
The main function for using Dagger is `spawn`:
5+
The main entrypoint to Dagger is `@spawn`:
66

7-
`Dagger.spawn(f, args...; options...)`
7+
`Dagger.@spawn [option=value]... f(args...; kwargs...)`
88

9-
or `@spawn` for the more convenient macro form:
9+
or `spawn` if it's more convenient:
1010

11-
`Dagger.@spawn [option=value]... f(args...)`
11+
`Dagger.spawn(f, Dagger.Options(options), args...; kwargs...)`
1212

1313
When called, it creates an `EagerThunk` (also known as a "thunk" or "task")
14-
object representing a call to function `f` with the arguments `args`. If it is
15-
called with other thunks as inputs, such as in `Dagger.@spawn f(Dagger.@spawn
16-
g())`, then the function `f` gets passed the results of those input thunks. If
17-
those thunks aren't yet finished executing, then the execution of `f` waits on
18-
all of its input thunks to complete before executing.
14+
object representing a call to function `f` with the arguments `args` and
15+
keyword arguments `kwargs`. If it is called with other thunks as args/kwargs,
16+
such as in `Dagger.@spawn f(Dagger.@spawn g())`, then the function `f` gets
17+
passed the results of those input thunks, once they're available. If those
18+
thunks aren't yet finished executing, then the execution of `f` waits on all of
19+
its input thunks to complete before executing.
1920

2021
The key point is that, for each argument to a thunk, if the argument is an
2122
`EagerThunk`, it'll be executed before this node and its result will be passed
2223
into the function `f`. If the argument is *not* an `EagerThunk` (instead, some
2324
other type of Julia object), it'll be passed as-is to the function `f`.
2425

25-
Thunks don't accept regular keyword arguments for the function `f`. Instead,
26-
the `options` kwargs are passed to the scheduler to control its behavior:
26+
The `Options` struct in the second argument position is optional; if provided,
27+
it is passed to the scheduler to control its behavior. `Options` contains a
28+
`NamedTuple` of option key-value pairs, which can be any of:
2729
- Any field in `Dagger.Sch.ThunkOptions` (see [Scheduler and Thunk options](@ref))
2830
- `meta::Bool` -- Pass the input `Chunk` objects themselves to `f` and not the value contained in them
2931

30-
There are also some extra kwargs that can be passed, although they're considered advanced options to be used only by developers or library authors:
32+
There are also some extra optionss that can be passed, although they're considered advanced options to be used only by developers or library authors:
3133
- `get_result::Bool` -- return the actual result to the scheduler instead of `Chunk` objects. Used when `f` explicitly constructs a Chunk or when return value is small (e.g. in case of reduce)
3234
- `persist::Bool` -- the result of this Thunk should not be released after it becomes unused in the DAG
3335
- `cache::Bool` -- cache the result of this Thunk such that if the thunk is evaluated again, one can just reuse the cached value. If it’s been removed from cache, recompute the value.
@@ -133,18 +135,18 @@ via `@par` or `delayed`. The above computation can be executed with the lazy
133135
API by substituting `@spawn` with `@par` and `fetch` with `collect`:
134136

135137
```julia
136-
p = @par add1(4)
137-
q = @par add2(p)
138-
r = @par add1(3)
139-
s = @par combine(p, q, r)
138+
p = Dagger.@par add1(4)
139+
q = Dagger.@par add2(p)
140+
r = Dagger.@par add1(3)
141+
s = Dagger.@par combine(p, q, r)
140142

141143
@assert collect(s) == 16
142144
```
143145

144146
or similarly, in block form:
145147

146148
```julia
147-
s = @par begin
149+
s = Dagger.@par begin
148150
p = add1(4)
149151
q = add2(p)
150152
r = add1(3)
@@ -159,7 +161,7 @@ operation, you can call `compute` on the thunk. This will return a `Chunk`
159161
object which references the result (see [Chunks](@ref) for more details):
160162

161163
```julia
162-
x = @par 1+2
164+
x = Dagger.@par 1+2
163165
cx = compute(x)
164166
cx::Chunk
165167
@assert collect(cx) == 3
@@ -207,7 +209,7 @@ Scheduler options can be constructed and passed to `collect()` or `compute()`
207209
as the keyword argument `options` for lazy API usage:
208210

209211
```julia
210-
t = @par 1+2
212+
t = Dagger.@par 1+2
211213
opts = Dagger.Sch.SchedulerOptions(;single=1) # Execute on worker 1
212214

213215
compute(t; options=opts)
@@ -221,10 +223,9 @@ Thunk options can be passed to `@spawn/spawn`, `@par`, and `delayed` similarly:
221223
# Execute on worker 1
222224

223225
Dagger.@spawn single=1 1+2
224-
Dagger.spawn(+, 1, 2; single=1)
226+
Dagger.spawn(+, Dagger.Options(;single=1), 1, 2)
225227

226-
opts = Dagger.Sch.ThunkOptions(;single=1)
227-
delayed(+)(1, 2; options=opts)
228+
delayed(+; single=1)(1, 2)
228229
```
229230

230231
### Core vs. Worker Schedulers

docs/src/propagation.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Option Propagation
22

3-
Most options passed to Dagger are passed via `delayed` or `Dagger.@spawn`
3+
Most options passed to Dagger are passed via `@spawn/spawn` or `delayed`
44
directly. This works well when an option only needs to be set for a single
55
thunk, but is cumbersome when the same option needs to be set on multiple
66
thunks, or set recursively on thunks spawned within other thunks. Thankfully,

src/array/darray.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ function thunkize(ctx::Context, c::DArray; persist=true)
251251
if persist
252252
foreach(persist!, thunks)
253253
end
254-
Thunk(thunks...; meta=true) do results...
254+
Thunk(map(thunk->nothing=>thunk, thunks)...; meta=true) do results...
255255
t = eltype(results[1])
256256
DArray(t, dmn, dmnchunks,
257257
reshape(Union{Chunk,Thunk}[results...], sz))

src/array/map-reduce.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function stage(ctx::Context, node::Map)
1919
f = node.f
2020
for i=eachindex(domains)
2121
inps = map(x->chunks(x)[i], inputs)
22-
thunks[i] = Thunk((args...) -> map(f, args...), inps...)
22+
thunks[i] = Thunk((args...) -> map(f, args...), map(inp->nothing=>inp, inps)...)
2323
end
2424
DArray(Any, domain(primary), domainchunks(primary), thunks)
2525
end
@@ -40,8 +40,8 @@ end
4040

4141
function stage(ctx::Context, r::ReduceBlock)
4242
inp = stage(ctx, r.input)
43-
reduced_parts = map(x -> Thunk(r.op, x; get_result=r.get_result), chunks(inp))
44-
Thunk((xs...) -> r.op_master(xs), reduced_parts...; meta=true)
43+
reduced_parts = map(x -> Thunk(r.op, nothing=>x; get_result=r.get_result), chunks(inp))
44+
Thunk((xs...) -> r.op_master(xs), map(part->nothing=>part, reduced_parts)...; meta=true)
4545
end
4646

4747
reduceblock_async(f, x::ArrayOp; get_result=true) = ReduceBlock(f, f, x, get_result)
@@ -126,10 +126,10 @@ function stage(ctx::Context, r::Reducedim)
126126
inp = cached_stage(ctx, r.input)
127127
thunks = let op = r.op, dims=r.dims
128128
# do reducedim on each block
129-
tmp = map(p->Thunk(b->reduce(op,b,dims=dims), p), chunks(inp))
129+
tmp = map(p->Thunk(b->reduce(op,b,dims=dims), nothing=>p), chunks(inp))
130130
# combine the results in tree fashion
131131
treereducedim(tmp, r.dims) do x,y
132-
Thunk(op, x,y)
132+
Thunk(op, nothing=>x, nothing=>y)
133133
end
134134
end
135135
c = domainchunks(inp)

src/array/matrix.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ function size(x::Transpose)
1717
end
1818

1919
transpose(x::ArrayOp) = Transpose(transpose, x)
20-
transpose(x::Union{Chunk, Thunk}) = Thunk(transpose, x)
20+
transpose(x::Union{Chunk, Thunk}) = Thunk(transpose, nothing=>x)
2121

2222
adjoint(x::ArrayOp) = Transpose(adjoint, x)
23-
adjoint(x::Union{Chunk, Thunk}) = Thunk(adjoint, x)
23+
adjoint(x::Union{Chunk, Thunk}) = Thunk(adjoint, nothing=>x)
2424

2525
function adjoint(x::ArrayDomain{2})
2626
d = indexes(x)
@@ -91,8 +91,8 @@ function (+)(a::ArrayDomain, b::ArrayDomain)
9191
a
9292
end
9393

94-
(*)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(*, a,b)
95-
(+)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(+, a,b)
94+
(*)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(*, nothing=>a, nothing=>b)
95+
(+)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(+, nothing=>a, nothing=>b)
9696

9797
# we define our own matmat and matvec multiply
9898
# for computing the new domains and thunks.
@@ -211,7 +211,7 @@ end
211211
function _scale(l, r)
212212
res = similar(r, Any)
213213
for i=1:length(l)
214-
res[i,:] = map(x->Thunk((a,b) -> Diagonal(a)*b, l[i], x), r[i,:])
214+
res[i,:] = map(x->Thunk((a,b) -> Diagonal(a)*b, nothing=>l[i], nothing=>x), r[i,:])
215215
end
216216
res
217217
end

src/array/operators.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ Base.@deprecate mappart(args...) mapchunk(args...)
107107
function stage(ctx::Context, node::MapChunk)
108108
inputs = map(x->cached_stage(ctx, x), node.input)
109109
thunks = map(map(chunks, inputs)...) do ps...
110-
Thunk(node.f, ps...)
110+
Thunk(node.f, map(p->nothing=>p, ps)...)
111111
end
112112

113113
DArray(Any, domain(inputs[1]), domainchunks(inputs[1]), thunks)

src/array/setindex.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function stage(ctx::Context, sidx::SetIndex)
3535
local_dmn = ArrayDomain(map(x->x[2], idx_and_dmn))
3636
s = subdmns[idx...]
3737
part_to_set = sidx.val
38-
ps[idx...] = Thunk(ps[idx...]) do p
38+
ps[idx...] = Thunk(nothing=>ps[idx...]) do p
3939
q = copy(p)
4040
q[indexes(project(s, local_dmn))...] .= part_to_set
4141
q

src/compute.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ function dependents(node::Thunk)
9797
if !haskey(deps, next)
9898
deps[next] = Set{Thunk}()
9999
end
100-
for inp in inputs(next)
100+
for (_, inp) in next.inputs
101101
if istask(inp) || (inp isa Chunk)
102102
s = get!(()->Set{Thunk}(), deps, inp)
103103
push!(s, next)
@@ -165,7 +165,7 @@ function order(node::Thunk, ndeps)
165165
haskey(output, next) && continue
166166
s += 1
167167
output[next] = s
168-
parents = filter(istask, inputs(next))
168+
parents = filter(istask, map(last, next.inputs))
169169
if !isempty(parents)
170170
# If parents is empty, sort! should be a no-op, but raises an ambiguity error
171171
# when InlineStrings.jl is loaded (at least, version 1.1.0), because InlineStrings

src/processor.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ function delete_processor_callback!(name::Symbol)
3131
end
3232

3333
"""
34-
execute!(proc::Processor, f, args...) -> Any
34+
execute!(proc::Processor, f, args...; kwargs...) -> Any
3535
36-
Executes the function `f` with arguments `args` on processor `proc`. This
37-
function can be overloaded by `Processor` subtypes to allow executing function
38-
calls differently than normal Julia.
36+
Executes the function `f` with arguments `args` and keyword arguments `kwargs`
37+
on processor `proc`. This function can be overloaded by `Processor` subtypes to
38+
allow executing function calls differently than normal Julia.
3939
"""
4040
function execute! end
4141

@@ -154,12 +154,12 @@ end
154154
iscompatible(proc::ThreadProc, opts, f, args...) = true
155155
iscompatible_func(proc::ThreadProc, opts, f) = true
156156
iscompatible_arg(proc::ThreadProc, opts, x) = true
157-
function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...))
157+
function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @nospecialize(kwargs...))
158158
tls = get_tls()
159159
task = Task() do
160160
set_tls!(tls)
161161
TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id)
162-
@invokelatest f(args...)
162+
@invokelatest f(args...; kwargs...)
163163
end
164164
task.sticky = true
165165
ret = ccall(:jl_set_task_tid, Cint, (Any, Cint), task, proc.tid-1)

0 commit comments

Comments
 (0)