Commit 7c3eef0

Add keyword argument support
APIs like `delayed` and `spawn` assumed that passed kwargs were to be treated as options to the scheduler, which is both somewhat confusing for users, and precludes passing kwargs to user functions.

This commit changes those APIs, as well as `@spawn`, to instead pass kwargs directly to the user's function. Options are now passed in an `Options` struct to `delayed` and `spawn` as the second argument (the first being the function), while `@spawn` still keeps them before the call (which is generally more convenient).

Internally, `Thunk`'s `inputs` field is now a `Vector{Pair{Union{Symbol,Nothing},Any}}`, where the second element of each pair is the argument, while the first element is a position; if `nothing`, it's a positional argument, and if a `Symbol`, then it's a kwarg.
1 parent 4b83c4b commit 7c3eef0
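The `inputs` layout described in the commit message can be sketched in plain Julia. This is a toy model only: `ToyInputs` and `pack_inputs` are hypothetical names introduced for illustration and are not part of Dagger.

```julia
# Toy model of the `inputs` layout; `ToyInputs` and `pack_inputs` are
# hypothetical names, not Dagger API.
const ToyInputs = Vector{Pair{Union{Symbol,Nothing},Any}}

# Tag positional arguments with `nothing` and keyword arguments with their name.
function pack_inputs(args...; kwargs...)
    inputs = ToyInputs()
    for arg in args
        push!(inputs, nothing => arg)
    end
    for (name, value) in kwargs
        push!(inputs, name => value)
    end
    return inputs
end

inputs = pack_inputs(1, 2; dims=1)

# Reading the layout back: `nothing` marks a positional argument, a `Symbol` a kwarg.
positional = [value for (pos, value) in inputs if pos === nothing]
keywords = [pos => value for (pos, value) in inputs if pos isa Symbol]
```

Keeping both kinds of argument in one vector means code that only cares about the values can ignore the first element of each pair.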

14 files changed: 203 additions, 112 deletions

docs/src/checkpointing.md

+3 -3
@@ -54,17 +54,17 @@ Let's see how we'd modify the above example to use checkpointing:

 ```julia
 using Serialization
+
 X = compute(randn(Blocks(128,128), 1024, 1024))
-Y = [delayed(sum; options=Dagger.Sch.ThunkOptions(;
-checkpoint=(thunk,result)->begin
+Y = [delayed(sum; checkpoint=(thunk,result)->begin
 open("checkpoint-$idx.bin", "w") do io
 serialize(io, collect(result))
 end
 end, restore=(thunk)->begin
 open("checkpoint-$idx.bin", "r") do io
 Dagger.tochunk(deserialize(io))
 end
-end))(chunk) for (idx,chunk) in enumerate(X.chunks)]
+end)(chunk) for (idx,chunk) in enumerate(X.chunks)]
 inner(x...) = sqrt(sum(x))
 Z = delayed(inner)(Y...)
 z = collect(Z)

docs/src/index.md

+23 -22
@@ -2,32 +2,34 @@

 ## Usage

-The main function for using Dagger is `spawn`:
+The main entrypoint to Dagger is `@spawn`:

-`Dagger.spawn(f, args...; options...)`
+`Dagger.@spawn [option=value]... f(args...; kwargs...)`

-or `@spawn` for the more convenient macro form:
+or `spawn` if it's more convenient:

-`Dagger.@spawn [option=value]... f(args...)`
+`Dagger.spawn(f, Dagger.Options(options), args...; kwargs...)`

 When called, it creates an `EagerThunk` (also known as a "thunk" or "task")
-object representing a call to function `f` with the arguments `args`. If it is
-called with other thunks as inputs, such as in `Dagger.@spawn f(Dagger.@spawn
-g())`, then the function `f` gets passed the results of those input thunks. If
-those thunks aren't yet finished executing, then the execution of `f` waits on
-all of its input thunks to complete before executing.
+object representing a call to function `f` with the arguments `args` and
+keyword arguments `kwargs`. If it is called with other thunks as args/kwargs,
+such as in `Dagger.@spawn f(Dagger.@spawn g())`, then the function `f` gets
+passed the results of those input thunks, once they're available. If those
+thunks aren't yet finished executing, then the execution of `f` waits on all of
+its input thunks to complete before executing.

 The key point is that, for each argument to a thunk, if the argument is an
 `EagerThunk`, it'll be executed before this node and its result will be passed
 into the function `f`. If the argument is *not* an `EagerThunk` (instead, some
 other type of Julia object), it'll be passed as-is to the function `f`.

-Thunks don't accept regular keyword arguments for the function `f`. Instead,
-the `options` kwargs are passed to the scheduler to control its behavior:
+The `Options` struct in the second argument position is optional; if provided,
+it is passed to the scheduler to control its behavior. `Options` contains a
+`NamedTuple` of option key-value pairs, which can be any of:
 - Any field in `Dagger.Sch.ThunkOptions` (see [Scheduler and Thunk options](@ref))
 - `meta::Bool` -- Pass the input `Chunk` objects themselves to `f` and not the value contained in them

-There are also some extra kwargs that can be passed, although they're considered advanced options to be used only by developers or library authors:
+There are also some extra options that can be passed, although they're considered advanced options to be used only by developers or library authors:
 - `get_result::Bool` -- return the actual result to the scheduler instead of `Chunk` objects. Used when `f` explicitly constructs a Chunk or when return value is small (e.g. in case of reduce)
 - `persist::Bool` -- the result of this Thunk should not be released after it becomes unused in the DAG
 - `cache::Bool` -- cache the result of this Thunk such that if the thunk is evaluated again, one can just reuse the cached value. If it’s been removed from cache, recompute the value.
@@ -133,18 +135,18 @@ via `@par` or `delayed`. The above computation can be executed with the lazy
 API by substituting `@spawn` with `@par` and `fetch` with `collect`:

 ```julia
-p = @par add1(4)
-q = @par add2(p)
-r = @par add1(3)
-s = @par combine(p, q, r)
+p = Dagger.@par add1(4)
+q = Dagger.@par add2(p)
+r = Dagger.@par add1(3)
+s = Dagger.@par combine(p, q, r)

 @assert collect(s) == 16
 ```

 or similarly, in block form:

 ```julia
-s = @par begin
+s = Dagger.@par begin
 p = add1(4)
 q = add2(p)
 r = add1(3)
@@ -159,7 +161,7 @@ operation, you can call `compute` on the thunk. This will return a `Chunk`
 object which references the result (see [Chunks](@ref) for more details):

 ```julia
-x = @par 1+2
+x = Dagger.@par 1+2
 cx = compute(x)
 cx::Chunk
 @assert collect(cx) == 3
@@ -207,7 +209,7 @@ Scheduler options can be constructed and passed to `collect()` or `compute()`
 as the keyword argument `options` for lazy API usage:

 ```julia
-t = @par 1+2
+t = Dagger.@par 1+2
 opts = Dagger.Sch.SchedulerOptions(;single=1) # Execute on worker 1

 compute(t; options=opts)
@@ -221,10 +223,9 @@ Thunk options can be passed to `@spawn/spawn`, `@par`, and `delayed` similarly:
 # Execute on worker 1

 Dagger.@spawn single=1 1+2
-Dagger.spawn(+, 1, 2; single=1)
+Dagger.spawn(+, Dagger.Options(;single=1), 1, 2)

-opts = Dagger.Sch.ThunkOptions(;single=1)
-delayed(+)(1, 2; options=opts)
+delayed(+; single=1)(1, 2)
 ```

 ### Core vs. Worker Schedulers
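The calling convention documented in this file, with scheduler options in a struct after the function and keyword arguments forwarded to `f`, can be modelled without Dagger. The sketch below is a simplification under stated assumptions: `ToyOptions` and `toy_spawn` are hypothetical stand-ins, and the call runs immediately instead of being scheduled.

```julia
# Toy model of the calling convention; `ToyOptions` and `toy_spawn` are
# hypothetical stand-ins, not Dagger's `Options` or `spawn`.
struct ToyOptions
    options::NamedTuple
end
ToyOptions(; options...) = ToyOptions((; options...))

# With an options struct in the second position, scheduler options are split off
# by dispatch and every keyword argument is free to go to `f`.
function toy_spawn(f, opts::ToyOptions, args...; kwargs...)
    return (options=opts.options, result=f(args...; kwargs...))
end

# The options struct is optional; without it, no options are set.
toy_spawn(f, args...; kwargs...) = toy_spawn(f, ToyOptions(), args...; kwargs...)

with_opts = toy_spawn(sum, ToyOptions(; single=1), [1 2; 3 4]; dims=1)
without_opts = toy_spawn(sum, [1 2; 3 4]; dims=1)
```

In both calls `dims=1` reaches `sum`, which the previous convention could not express because every keyword argument was read as a scheduler option.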

docs/src/propagation.md

+1 -1
@@ -1,6 +1,6 @@
 # Option Propagation

-Most options passed to Dagger are passed via `delayed` or `Dagger.@spawn`
+Most options passed to Dagger are passed via `@spawn/spawn` or `delayed`
 directly. This works well when an option only needs to be set for a single
 thunk, but is cumbersome when the same option needs to be set on multiple
 thunks, or set recursively on thunks spawned within other thunks. Thankfully,

src/compute.jl

+2 -2
@@ -97,7 +97,7 @@ function dependents(node::Thunk)
 if !haskey(deps, next)
 deps[next] = Set{Thunk}()
 end
-for inp in inputs(next)
+for (_, inp) in next.inputs
 if istask(inp) || (inp isa Chunk)
 s = get!(()->Set{Thunk}(), deps, inp)
 push!(s, next)
@@ -165,7 +165,7 @@ function order(node::Thunk, ndeps)
 haskey(output, next) && continue
 s += 1
 output[next] = s
-parents = filter(istask, inputs(next))
+parents = filter(istask, map(last, next.inputs))
 if !isempty(parents)
 # If parents is empty, sort! should be a no-op, but raises an ambiguity error
 # when InlineStrings.jl is loaded (at least, version 1.1.0), because InlineStrings
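Both hunks in this file apply the same adjustment: dependency tracking only needs the values, so the position half of each pair is dropped before filtering. A small standalone sketch follows, where `ToyTask` and `toy_istask` are hypothetical stand-ins for `Thunk` and `istask`.

```julia
# Toy stand-in for a task node; `ToyTask` and `toy_istask` are hypothetical names.
struct ToyTask
    id::Int
end
toy_istask(x) = x isa ToyTask

# Inputs are `position => value` pairs; the position is `nothing` for a
# positional argument and a `Symbol` for a keyword argument.
inputs = Pair{Union{Symbol,Nothing},Any}[nothing => ToyTask(1), nothing => 42, :scale => ToyTask(2)]

# Drop the positions first, then keep only the values that are tasks.
parents = filter(toy_istask, map(last, inputs))

# The same idea in loop form, destructuring each pair and ignoring the position.
parent_ids = Int[]
for (_, inp) in inputs
    toy_istask(inp) && push!(parent_ids, inp.id)
end
```

A task passed as a keyword argument (`:scale` here) is picked up as a dependency in the same way as a positional one.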

src/processor.jl

+6 -6
@@ -31,11 +31,11 @@ function delete_processor_callback!(name::Symbol)
 end

 """
-execute!(proc::Processor, f, args...) -> Any
+execute!(proc::Processor, f, args...; kwargs...) -> Any

-Executes the function `f` with arguments `args` on processor `proc`. This
-function can be overloaded by `Processor` subtypes to allow executing function
-calls differently than normal Julia.
+Executes the function `f` with arguments `args` and keyword arguments `kwargs`
+on processor `proc`. This function can be overloaded by `Processor` subtypes to
+allow executing function calls differently than normal Julia.
 """
 function execute! end

@@ -154,12 +154,12 @@ end
 iscompatible(proc::ThreadProc, opts, f, args...) = true
 iscompatible_func(proc::ThreadProc, opts, f) = true
 iscompatible_arg(proc::ThreadProc, opts, x) = true
-function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...))
+function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @nospecialize(kwargs...))
 tls = get_tls()
 task = Task() do
 set_tls!(tls)
 TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id)
-@invokelatest f(args...)
+@invokelatest f(args...; kwargs...)
 end
 task.sticky = true
 ret = ccall(:jl_set_task_tid, Cint, (Any, Cint), task, proc.tid-1)
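The change to `execute!` amounts to forwarding keyword arguments into the task that runs the user's function. A stripped-down sketch of that forwarding, with no thread pinning, TLS, or profiling; `toy_execute` is a hypothetical helper and not Dagger's `execute!`.

```julia
# Toy version of running a call inside a task; `toy_execute` is a hypothetical
# helper, not Dagger's `execute!`.
function toy_execute(f, args...; kwargs...)
    task = Task() do
        # `Base.invokelatest` forwards keyword arguments like a direct call would.
        Base.invokelatest(f, args...; kwargs...)
    end
    schedule(task)
    return fetch(task)
end

sorted = toy_execute(sort, [3, 1, 2]; rev=true)
```

This sketch uses the function form `Base.invokelatest` where the diff uses the `@invokelatest` macro; both accept keyword arguments.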

src/sch/Sch.jl

+39 -13
@@ -130,7 +130,7 @@ function start_state(deps::Dict, node_order, chan)

 for k in sort(collect(keys(deps)), by=node_order)
 if istask(k)
-waiting = Set{Thunk}(Iterators.filter(istask, inputs(k)))
+waiting = Set{Thunk}(Iterators.filter(istask, map(last, inputs(k))))
 if isempty(waiting)
 push!(state.ready, k)
 else
@@ -659,7 +659,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
 DefaultScope()
 end
 end
-for input in task.inputs
+for (_,input) in task.inputs
 input = unwrap_weak_checked(input)
 chunk = if istask(input)
 state.cache[input]
@@ -688,7 +688,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
 @goto fallback
 end

-inputs = collect_task_inputs(state, task)
+inputs = map(last, collect_task_inputs(state, task))
 opts = populate_defaults(opts, chunktype(task.f), map(chunktype, inputs))
 local_procs, costs = estimate_task_costs(state, local_procs, task, inputs)
 scheduled = false
@@ -945,10 +945,13 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)

 ids = Int[0]
 data = Any[thunk.f]
-for (idx, x) in enumerate(thunk.inputs)
+positions = Union{Symbol,Nothing}[]
+for (idx, pos_x) in enumerate(thunk.inputs)
+pos, x = pos_x
 x = unwrap_weak_checked(x)
 push!(ids, istask(x) ? x.id : -idx)
 push!(data, istask(x) ? state.cache[x] : x)
+push!(positions, pos)
 end
 toptions = thunk.options !== nothing ? thunk.options : ThunkOptions()
 options = merge(ctx.options, toptions)
@@ -961,7 +964,7 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
 push!(to_send, Any[thunk.id, time_util, alloc_util, occupancy,
 scope, chunktype(thunk.f), data,
 thunk.get_result, thunk.persist, thunk.cache, thunk.meta, options,
-propagated, ids,
+propagated, ids, positions,
 (log_sink=ctx.log_sink, profile=ctx.profile),
 sch_handle, state.uid])
 end
@@ -1093,6 +1096,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re

 while isopen(return_queue)
 # Wait for new tasks
+@dagdebug nothing :processor "Waiting for tasks"
 timespan_start(ctx, :proc_run_wait, to_proc, nothing)
 wait(istate.reschedule)
 @static if VERSION >= v"1.9"
@@ -1101,22 +1105,27 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
 timespan_finish(ctx, :proc_run_wait, to_proc, nothing)

 # Fetch a new task to execute
+@dagdebug nothing :processor "Trying to dequeue"
 timespan_start(ctx, :proc_run_fetch, to_proc, nothing)
 task_and_occupancy = lock(istate.queue) do queue
 # Only steal if there are multiple queued tasks, to prevent
 # ping-pong of tasks between empty queues
 if length(queue) == 0
+@dagdebug nothing :processor "Nothing to dequeue"
 return nothing
 end
 _, occupancy = peek(queue)
-if proc_has_occupancy(proc_occupancy[], occupancy)
-return dequeue_pair!(queue)
+if !proc_has_occupancy(proc_occupancy[], occupancy)
+@dagdebug nothing :processor "Insufficient occupancy" proc_occupancy=proc_occupancy[] task_occupancy=occupancy
+return nothing
 end
-return nothing
+return dequeue_pair!(queue)
 end
 if task_and_occupancy === nothing
 timespan_finish(ctx, :proc_run_fetch, to_proc, nothing)

+@dagdebug nothing :processor "Failed to dequeue"
+
 if !stealing_permitted(to_proc)
 continue
 end
@@ -1125,6 +1134,8 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
 continue
 end

+@dagdebug nothing :processor "Trying to steal"
+
 # Try to steal a task
 timespan_start(ctx, :steal_local, to_proc, nothing)

@@ -1159,7 +1170,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
 if task_and_occupancy !== nothing
 from_proc = other_istate.proc
 thunk_id = task[1]
-@dagdebug thunk_id :execute "Stolen from $from_proc by $to_proc"
+@dagdebug thunk_id :processor "Stolen from $from_proc by $to_proc"
 timespan_finish(ctx, :steal_local, to_proc, (;from_proc, thunk_id))
 # TODO: Keep stealing until we hit full occupancy?
 @goto execute
@@ -1177,6 +1188,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
 thunk_id = task[1]
 time_util = task[2]
 timespan_finish(ctx, :proc_run_fetch, to_proc, (;thunk_id, proc_occupancy=proc_occupancy[], task_occupancy))
+@dagdebug thunk_id :processor "Dequeued task"

 # Execute the task and return its result
 t = @task begin
@@ -1234,10 +1246,12 @@ Executes a batch of tasks on `to_proc`, returning their results through
 `return_queue`.
 """
 function do_tasks(to_proc, return_queue, tasks)
+@dagdebug nothing :processor "Enqueuing task batch" batch_size=length(tasks)
+
 # FIXME: This is terrible
-ctx_vars = first(tasks)[15]
+ctx_vars = first(tasks)[16]
 ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile)
-uid = first(tasks)[17]
+uid = first(tasks)[18]
 state = proc_states(uid) do states
 get!(states, to_proc) do
 queue = PriorityQueue{Vector{Any}, UInt32}()
@@ -1272,6 +1286,7 @@ function do_tasks(to_proc, return_queue, tasks)
 should_launch || continue
 enqueue!(queue, task, occupancy)
 timespan_finish(ctx, :enqueue, (;to_proc, thunk_id), nothing)
+@dagdebug thunk_id :processor "Enqueued task"
 end
 end
 notify(istate.reschedule)
@@ -1287,6 +1302,7 @@ function do_tasks(to_proc, return_queue, tasks)
 end
 notify(other_istate.reschedule)
 end
+@dagdebug nothing :processor "Kicked processors"
 end

 """
@@ -1298,7 +1314,7 @@ function do_task(to_proc, task_desc)
 thunk_id, est_time_util, est_alloc_util, est_occupancy,
 scope, Tf, data,
 send_result, persist, cache, meta,
-options, propagated, ids, ctx_vars, sch_handle, uid = task_desc
+options, propagated, ids, positions, ctx_vars, sch_handle, uid = task_desc
 ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile)

 from_proc = OSProc()
@@ -1442,6 +1458,16 @@ function do_task(to_proc, task_desc)
 end
 f = popfirst!(fetched)
 @assert !(f isa Chunk) "Failed to unwrap thunk function"
+fetched_args = Any[]
+fetched_kwargs = Pair{Symbol,Any}[]
+for (idx, x) in enumerate(fetched)
+pos = positions[idx]
+if pos === nothing
+push!(fetched_args, x)
+else
+push!(fetched_kwargs, pos => x)
+end
+end

 #= FIXME: If MaxUtilization, stop processors and wait
 if (est_time_util isa MaxUtilization) && (real_time_util > 0)
@@ -1473,7 +1499,7 @@ function do_task(to_proc, task_desc)

 res = Dagger.with_options(propagated) do
 # Execute
-execute!(to_proc, f, fetched...)
+execute!(to_proc, f, fetched_args...; fetched_kwargs...)
 end

 # Check if result is safe to store
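Taken together, the `fire_tasks!` and `do_task` hunks form a round trip: the inputs are flattened into a `data` vector plus a parallel `positions` vector on the sending side, then split back into positional and keyword arguments just before the call. The sketch below reproduces that flow in isolation. `flatten_call` and `rebuild_call` are hypothetical helpers, and the weak-reference unwrapping, chunk caching, and ID bookkeeping from the real code are left out.

```julia
# Toy round trip of the flattening in `fire_tasks!` and the rebuilding in `do_task`.
# `flatten_call` and `rebuild_call` are hypothetical helpers, not Dagger functions.
function flatten_call(f, inputs)
    data = Any[f]                          # slot 1 holds the function
    positions = Union{Symbol,Nothing}[]    # one entry per input, parallel to data[2:end]
    for (pos, x) in inputs
        push!(data, x)
        push!(positions, pos)
    end
    return data, positions
end

function rebuild_call(data, positions)
    fetched = copy(data)
    f = popfirst!(fetched)                 # remaining entries line up with `positions`
    args = Any[]
    kwargs = Pair{Symbol,Any}[]
    for (idx, x) in enumerate(fetched)
        pos = positions[idx]
        if pos === nothing
            push!(args, x)
        else
            push!(kwargs, pos => x)
        end
    end
    return f(args...; kwargs...)
end

data, positions = flatten_call(round, Any[nothing => 3.14159, :digits => 2])
result = rebuild_call(data, positions)
```

Because `positions` has one entry per input and none for the function, the indices only line up after the function has been removed from the front of `data`, which matches the order of operations in the diff.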

src/sch/dynamic.jl

+6 -6
@@ -144,7 +144,7 @@ function _register_future!(ctx, state, task, tid, (future, id, check)::Tuple{Thu
 if t == target
 return true
 end
-for input in t.inputs
+for (_, input) in t.inputs
 # N.B. Skips expired tasks
 input = Dagger.unwrap_weak(input)
 istask(input) || continue
@@ -195,13 +195,13 @@ function _get_dag_ids(ctx, state, task, tid, _)
 end

 "Adds a new Thunk to the DAG."
-add_thunk!(f, h::SchedulerHandle, args...; future=nothing, ref=nothing, kwargs...) =
-exec!(_add_thunk!, h, f, args, kwargs, future, ref)
-function _add_thunk!(ctx, state, task, tid, (f, args, kwargs, future, ref))
+add_thunk!(f, h::SchedulerHandle, args...; future=nothing, ref=nothing, options...) =
+exec!(_add_thunk!, h, f, args, options, future, ref)
+function _add_thunk!(ctx, state, task, tid, (f, args, options, future, ref))
 timespan_start(ctx, :add_thunk, tid, 0)
-_args = map(arg->arg isa ThunkID ? state.thunk_dict[arg.id] : arg, args)
+_args = map(pos_arg->pos_arg[1] => (pos_arg[2] isa ThunkID ? state.thunk_dict[pos_arg[2].id] : pos_arg[2]), args)
 GC.@preserve _args begin
-thunk = Thunk(f, _args...; kwargs...)
+thunk = Thunk(f, _args...; options...)
 # Create a `DRef` to `thunk` so that the caller can preserve it
 thunk_ref = poolset(thunk; size=64, device=MemPool.CPURAMDevice())
 thunk_id = ThunkID(thunk.id, thunk_ref)
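The rewritten `_args` line in `_add_thunk!` keeps each position and swaps only the values that are IDs for the objects they refer to. A standalone sketch of that mapping, where `ToyThunkID` and `thunk_dict` are hypothetical stand-ins for `ThunkID` and the scheduler's `state.thunk_dict`:

```julia
# Toy version of resolving ID placeholders inside `position => value` pairs.
# `ToyThunkID` and `thunk_dict` are hypothetical stand-ins for scheduler state.
struct ToyThunkID
    id::Int
end

thunk_dict = Dict(1 => "thunk one", 2 => "thunk two")
args = Any[nothing => ToyThunkID(1), nothing => 10, :offset => ToyThunkID(2)]

# Keep each position untouched and swap only the values that are IDs.
resolved = map(args) do (pos, value)
    pos => (value isa ToyThunkID ? thunk_dict[value.id] : value)
end
```

The destructuring `do (pos, value)` form is equivalent to indexing `pos_arg[1]` and `pos_arg[2]` as the diff does.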
