Skip to content

Commit b9e22d1

Browse files
committed
options: Add dispatch-based options
1 parent 64a87d8 commit b9e22d1

File tree

5 files changed

+120
-91
lines changed

5 files changed

+120
-91
lines changed

Diff for: Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
88
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
99
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1112
MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"
1213
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

Diff for: src/Dagger.jl

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using UUIDs
1515
import ContextVariablesX
1616

1717
using Requires
18+
using MacroTools
1819

1920
const PLUGINS = Dict{Symbol,Any}()
2021
const PLUGIN_CONFIGS = Dict{Symbol,String}(

Diff for: src/options.jl

+45
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,48 @@ function get_options(key::Symbol, default)
1515
opts = get_options()
1616
return haskey(opts, key) ? getproperty(opts, key) : default
1717
end
18+
19+
# Dispatch-based setters
20+
21+
"""
    default_option(::Val{name}, Tf, Targs...) where name = value

Provide `value` as the default for option `name` whenever Dagger prepares to
run a function of type `Tf` on arguments of types `Targs`. Users and libraries
may add methods to this function to supply per-task defaults; the
[`@option`](@ref) macro is the more convenient way to write such methods.

Only the *types* of the task arguments are passed, never their values, since
gathering every argument of a Dagger task on a single worker is not always
possible or efficient.

Methods may run inside the scheduler, so keep them very cheap. If a method
throws, the scheduler uses the global default for that option instead.

The generic fallback returns `nothing`. Calling with only the `Val` (no
function type) throws an `ArgumentError`.
"""
function default_option(::Val{name}, Tf, Targs...) where {name}
    return nothing
end
function default_option(::Val)
    throw(ArgumentError("default_option requires a function type and any argument types"))
end
40+
41+
"""
    @option name myfunc(A, B, C) = value

Convenience macro that defines a [`default_option`](@ref) method. `name` is the
option to set, `myfunc` is the function the default applies to, `A, B, C` are
the argument *types* (subtypes match too), and `value` is the default. Example:

```julia
Dagger.@option single mylocalfunc(Int) = 1
```

After this, any Dagger task that calls `mylocalfunc` with a single `Int`
argument gets `single = 1` unless the option was set explicitly.

`name` may be written bare (`single`) or quoted (`:single`). An `ArgumentError`
is thrown at macro-expansion time if the second argument is not of the form
`f(ArgTypes...) = value`.
"""
macro option(name, ex)
    # A bare `single` arrives as a Symbol; spliced as-is it would become a
    # *variable reference* inside `Val{...}`. Quote it so the type parameter is
    # the symbol itself. Anything else (e.g. the QuoteNode from `:single`) is
    # passed through unchanged, preserving the previous behavior.
    nameex = name isa Symbol ? QuoteNode(name) : name
    @capture(ex, f_(args__) = value_) ||
        throw(ArgumentError("@option expects `name f(ArgTypes...) = value`, got: $(repr(ex))"))
    args = esc.(args)
    # One fresh type variable per argument, so `f(Real)` also matches `Int`
    argsyms = map(_->gensym(), args)
    _args = map(arg->:(::$Type{$(argsyms[arg[1]])}), enumerate(args))
    argsubs = map(arg->:($(argsyms[arg[1]])<:$(arg[2])), enumerate(args))
    quote
        Dagger.default_option(::$Val{$nameex}, ::Type{$typeof($(esc(f)))}, $(_args...)) where {$(argsubs...)} = $(esc(value))
    end
end

Diff for: src/sch/Sch.jl

+53-66
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,13 @@ If this returns a `Chunk`, all thunks will be skipped, and the `Chunk` will be
156156
returned. If `nothing` is returned, restoring is skipped, and the scheduler
157157
will execute as usual. If this function throws an error, restoring will be
158158
skipped, and the error will be displayed.
159-
- `round_robin::Bool=false`: Whether to schedule in round-robin mode, which
160-
spreads load instead of the default behavior of filling processors to capacity.
161159
"""
162160
Base.@kwdef struct SchedulerOptions
163-
single::Int = 0
161+
single::Union{Int,Nothing} = nothing
164162
proclist = nothing
165-
allow_errors::Bool = false
163+
allow_errors::Union{Bool,Nothing} = false
166164
checkpoint = nothing
167165
restore = nothing
168-
round_robin::Bool = false
169166
end
170167

171168
"""
@@ -209,11 +206,11 @@ device must support `MemPool.CPURAMResource`. When `nothing`, uses
209206
`MemPool.GLOBAL_DEVICE[]`.
210207
"""
211208
Base.@kwdef struct ThunkOptions
212-
single::Int = 0
209+
single::Union{Int,Nothing} = nothing
213210
proclist = nothing
214-
time_util::Dict{Type,Any} = Dict{Type,Any}()
215-
alloc_util::Dict{Type,UInt64} = Dict{Type,UInt64}()
216-
allow_errors::Bool = false
211+
time_util::Union{Dict{Type,Any},Nothing} = nothing
212+
alloc_util::Union{Dict{Type,UInt64},Nothing} = nothing
213+
allow_errors::Union{Bool,Nothing} = nothing
217214
checkpoint = nothing
218215
restore = nothing
219216
storage::Union{Chunk,Nothing} = nothing
@@ -228,20 +225,50 @@ include("eager.jl")
228225
Combine `SchedulerOptions` and `ThunkOptions` into a new `ThunkOptions`.
229226
"""
230227
function Base.merge(sopts::SchedulerOptions, topts::ThunkOptions)
    # For the three options that exist at both levels, the thunk-level value
    # wins whenever it is set (not `nothing`); otherwise the scheduler-level
    # value is inherited.
    single = topts.single === nothing ? sopts.single : topts.single
    proclist = topts.proclist === nothing ? sopts.proclist : topts.proclist
    allow_errors = topts.allow_errors === nothing ? sopts.allow_errors : topts.allow_errors
    # All remaining fields are taken from the thunk options unchanged.
    return ThunkOptions(
        single,
        proclist,
        topts.time_util,
        topts.alloc_util,
        allow_errors,
        topts.checkpoint,
        topts.restore,
        topts.storage,
    )
end
236240
# Merge with no thunk-level options: inherit `single`, `proclist` and
# `allow_errors` from the scheduler options and leave every other field at its
# `ThunkOptions` default. The keyword constructor (from `Base.@kwdef`) is used
# because a positional call must supply every field of the struct.
Base.merge(sopts::SchedulerOptions, ::Nothing) =
    ThunkOptions(; single=sopts.single,
                 proclist=sopts.proclist,
                 allow_errors=sopts.allow_errors)
246+
"""
    populate_defaults(opts::ThunkOptions, Tf, Targs) -> ThunkOptions

Returns a `ThunkOptions` with default values filled in for a function of type
`Tf` with argument types `Targs`, if the option was previously unspecified in
`opts`.

An option counts as unspecified when its value is `nothing`; its default is
then looked up with `Dagger.default_option(Val(opt), Tf, Targs...)`. If that
lookup throws, a warning is logged and the option is left as `nothing`, so one
faulty `default_option` method does not abort the caller.
"""
function populate_defaults(opts::ThunkOptions, Tf, Targs)
    function maybe_default(opt::Symbol)
        old_opt = getproperty(opts, opt)
        # Explicitly set options always win over dispatch-based defaults
        old_opt !== nothing && return old_opt
        try
            return Dagger.default_option(Val(opt), Tf, Targs...)
        catch err
            # Never swallow a user interrupt
            err isa InterruptException && rethrow()
            @warn "Dagger.default_option threw for option `$opt`; leaving it unset" exception=(err, catch_backtrace())
            return nothing
        end
    end
    return ThunkOptions(
        maybe_default(:single),
        maybe_default(:proclist),
        maybe_default(:time_util),
        maybe_default(:alloc_util),
        maybe_default(:allow_errors),
        maybe_default(:checkpoint),
        maybe_default(:restore),
        maybe_default(:storage),
    )
end
246273

247274
function cleanup(ctx)
@@ -470,7 +497,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options)
470497
timespan_finish(ctx, :handle_fault, 0, 0)
471498
return # effectively `continue`
472499
else
473-
if ctx.options.allow_errors || unwrap_weak_checked(state.thunk_dict[thunk_id]).options.allow_errors
500+
if something(ctx.options.allow_errors, false) ||
501+
something(unwrap_weak_checked(state.thunk_dict[thunk_id]).options.allow_errors, false)
474502
thunk_failed = true
475503
else
476504
throw(res)
@@ -537,7 +565,7 @@ function scheduler_exit(ctx, state::ComputeState, options)
537565
end
538566

539567
function procs_to_use(ctx, options=ctx.options)
540-
return if options.single !== 0
568+
return if options.single !== nothing
541569
@assert options.single in vcat(1, workers()) "Sch option `single` must specify an active worker ID."
542570
OSProc[OSProc(options.single)]
543571
else
@@ -617,7 +645,9 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
617645
@goto fallback
618646
end
619647

620-
local_procs, costs = estimate_task_costs(state, local_procs, task)
648+
inputs = collect_task_inputs(state, task)
649+
opts = populate_defaults(opts, chunktype(task.f), map(chunktype, inputs))
650+
local_procs, costs = estimate_task_costs(state, local_procs, task, inputs)
621651
scheduled = false
622652

623653
# Move our corresponding ThreadProc to be the last considered
@@ -703,9 +733,6 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
703733
push!(get!(()->Vector{Tuple{Thunk,<:Any,<:Any}}(), to_fire, (gproc, proc)), (task, est_time_util, est_alloc_util))
704734

705735
# Proceed to next entry to spread work
706-
if !ctx.options.round_robin
707-
@warn "Round-robin mode is always on"
708-
end
709736
state.procs_cache_list[] = state.procs_cache_list[].next
710737
@goto pop_task
711738

@@ -777,46 +804,6 @@ function remove_dead_proc!(ctx, state, proc, options=ctx.options)
777804
state.procs_cache_list[] = nothing
778805
end
779806

780-
function pop_with_affinity!(ctx, tasks, proc)
781-
# TODO: use the size
782-
parent_affinity_procs = Vector(undef, length(tasks))
783-
# parent_affinity_sizes = Vector(undef, length(tasks))
784-
for i=length(tasks):-1:1
785-
t = tasks[i]
786-
aff = affinity(t)
787-
aff_procs = first.(aff)
788-
if proc in aff_procs
789-
if !isrestricted(t,proc)
790-
deleteat!(tasks, i)
791-
return t
792-
end
793-
end
794-
parent_affinity_procs[i] = aff_procs
795-
end
796-
for i=length(tasks):-1:1
797-
# use up tasks without affinities
798-
# let the procs with the respective affinities pick up
799-
# other tasks
800-
aff_procs = parent_affinity_procs[i]
801-
if isempty(aff_procs)
802-
t = tasks[i]
803-
if !isrestricted(t,proc)
804-
deleteat!(tasks, i)
805-
return t
806-
end
807-
end
808-
if all(!(p in aff_procs) for p in procs(ctx))
809-
# no proc is ever going to ask for it
810-
t = tasks[i]
811-
if !isrestricted(t,proc)
812-
deleteat!(tasks, i)
813-
return t
814-
end
815-
end
816-
end
817-
return nothing
818-
end
819-
820807
function finish_task!(ctx, state, node, thunk_failed)
821808
pop!(state.running, node)
822809
delete!(state.running_on, node)
@@ -909,12 +896,12 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
909896
toptions = thunk.options !== nothing ? thunk.options : ThunkOptions()
910897
options = merge(ctx.options, toptions)
911898
propagated = get_propagated_options(thunk)
912-
@assert (options.single == 0) || (gproc.pid == options.single)
899+
@assert (options.single === nothing) || (gproc.pid == options.single)
913900
# TODO: Set `sch_handle.tid.ref` to the right `DRef`
914901
sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...)
915902

916903
# TODO: De-dup common fields (log_sink, uid, etc.)
917-
push!(to_send, Any[thunk.id, time_util, alloc_util, fn_type(thunk.f), data, thunk.get_result,
904+
push!(to_send, Any[thunk.id, time_util, alloc_util, chunktype(thunk.f), data, thunk.get_result,
918905
thunk.persist, thunk.cache, thunk.meta, options,
919906
propagated, ids,
920907
(log_sink=ctx.log_sink, profile=ctx.profile),

Diff for: src/sch/util.jl

+20-25
Original file line numberDiff line numberDiff line change
@@ -236,15 +236,10 @@ function report_catch_error(err, desc=nothing)
236236
write(stderr, iob)
237237
end
238238

239-
fn_type(x::Chunk) = x.chunktype
240-
fn_type(x) = typeof(x)
239+
# Generic fallback: the "chunk type" of a plain value is just its own type.
# NOTE(review): a more specific method for `Chunk` is presumably defined
# elsewhere in the package — confirm.
function chunktype(x)
    return typeof(x)
end
241240
# Builds the type signature of `task`: the type of its function followed by
# the type of each input. Inputs that are themselves tasks are first replaced
# by their entry in `state.cache` (done by `collect_task_inputs`).
function signature(task::Thunk, state)
    sig = Any[chunktype(task.f)]
    for input in collect_task_inputs(state, task)
        # Record the *type* of each input, not the input value itself
        push!(sig, chunktype(input))
    end
    return sig
end
250245

@@ -270,7 +265,7 @@ function can_use_proc(task, gproc, proc, opts, scope)
270265
end
271266

272267
# Check against single
273-
if opts.single != 0
268+
if opts.single !== nothing
274269
if gproc.pid != opts.single
275270
@debug "Rejected $proc: gproc.pid != single"
276271
return false
@@ -290,7 +285,7 @@ end
290285
function has_capacity(state, p, gp, time_util, alloc_util, sig)
291286
T = typeof(p)
292287
# FIXME: MaxUtilization
293-
est_time_util = round(UInt64, if haskey(time_util, T)
288+
est_time_util = round(UInt64, if time_util !== nothing && haskey(time_util, T)
294289
time_util[T] * 1000^3
295290
else
296291
get(state.signature_time_cost, sig, 1000^3)
@@ -300,7 +295,9 @@ function has_capacity(state, p, gp, time_util, alloc_util, sig)
300295
real_alloc_util = state.worker_storage_pressure[gp][storage]
301296
real_alloc_cap = state.worker_storage_capacity[gp][storage]
302297
=#
303-
est_alloc_util = get(alloc_util, T) do
298+
est_alloc_util = if alloc_util !== nothing && haskey(alloc_util, T)
299+
alloc_util[T]
300+
else
304301
get(state.signature_alloc_cost, sig, 0)
305302
end
306303
#= FIXME
@@ -352,29 +349,27 @@ function impute_sum(xs)
352349
total + nothing_count * total / something_count
353350
end
354351

352+
"""
    collect_task_inputs(state, task) -> Vector{Any}

Gather every argument of `task` in order. Each input is unwrapped with
`unwrap_weak_checked`; an input that is itself a task is replaced by its entry
in `state.cache`, and any other input is kept as-is.
"""
function collect_task_inputs(state, task)
    return Any[
        let arg = unwrap_weak_checked(raw)
            istask(arg) ? state.cache[arg] : arg
        end for raw in task.inputs
    ]
end
361+
355362
"""
356363
Estimates the cost of scheduling `task` on each processor in `procs`. Considers
357364
current estimated per-processor compute pressure, and transfer costs for each
358365
`Chunk` argument to `task`. Returns `(procs, costs)`, with `procs` sorted in
359366
order of ascending cost.
360367
"""
361-
function estimate_task_costs(state, procs, task)
368+
function estimate_task_costs(state, procs, task, inputs)
362369
tx_rate = state.transfer_rate[]
363370

364371
# Find all Chunks
365-
chunks = Chunk[]
366-
for input in task.inputs
367-
input = unwrap_weak_checked(input)
368-
input_raw = istask(input) ? state.cache[input] : input
369-
if input_raw isa Chunk
370-
push!(chunks, input_raw)
371-
end
372-
end
373-
#=
374-
inputs = map(@nospecialize(input)->istask(input) ? state.cache[input] : input,
375-
map(@nospecialize(x)->unwrap_weak_checked(x), task.inputs))
376-
chunks = filter(@nospecialize(t)->isa(t, Chunk), inputs)
377-
=#
372+
chunks = filter(t->isa(t, Chunk), inputs)
378373

379374
# Estimate network transfer costs based on data size
380375
# N.B. `affinity(x)` really means "data size of `x`"

0 commit comments

Comments
 (0)