Skip to content

Commit d30e4a9

Browse files
committed
Sch: Fix incorrect task signature calculation
1 parent 69bcd77 commit d30e4a9

File tree

2 files changed

+85
-55
lines changed

2 files changed

+85
-55
lines changed

Diff for: src/sch/util.jl

+24-9
Original file line numberDiff line numberDiff line change
@@ -285,15 +285,30 @@ function report_catch_error(err, desc=nothing)
285285
end
286286

287287
chunktype(x) = typeof(x)
288-
signature(state, task::Thunk) = signature(state, task.f, task.inputs)
289-
function signature(state, f, inputs::Vector)
290-
sig = Any[chunktype(f)]
291-
for (pos, input) in collect_task_inputs(state, inputs)
292-
# N.B. Skips kwargs
288+
signature(state, task::Thunk) =
289+
signature(task.f, collect_task_inputs(state, task.inputs))
290+
function signature(f, args)
291+
sig = DataType[chunktype(f)]
292+
sig_kwarg_names = Symbol[]
293+
sig_kwarg_types = []
294+
for (pos, arg) in args
295+
if arg isa Dagger.DTask
296+
# Only occurs via manual usage of signature
297+
arg = fetch(arg; raw=true)
298+
end
299+
T = chunktype(arg)
293300
if pos === nothing
294-
push!(sig, chunktype(input))
301+
push!(sig, T)
302+
else
303+
push!(sig_kwarg_names, pos)
304+
push!(sig_kwarg_types, T)
295305
end
296306
end
307+
if !isempty(sig_kwarg_names)
308+
NT = NamedTuple{(sig_kwarg_names...,), Base.to_tuple_type(sig_kwarg_types)}
309+
pushfirst!(sig, NT)
310+
pushfirst!(sig, typeof(Core.kwcall))
311+
end
297312
return sig
298313
end
299314

@@ -423,12 +438,12 @@ end
423438
collect_task_inputs(state, task::Thunk) =
424439
collect_task_inputs(state, task.inputs)
425440
function collect_task_inputs(state, inputs)
426-
inputs = Pair{Union{Symbol,Nothing},Any}[]
441+
new_inputs = Pair{Union{Symbol,Nothing},Any}[]
427442
for (pos, input) in inputs
428443
input = unwrap_weak_checked(input)
429-
push!(inputs, pos => (istask(input) ? state.cache[input] : input))
444+
push!(new_inputs, pos => (istask(input) ? state.cache[input] : input))
430445
end
431-
return inputs
446+
return new_inputs
432447
end
433448

434449
"""

Diff for: test/scheduler.jl

+61-46
Original file line numberDiff line numberDiff line change
@@ -350,59 +350,74 @@ end
350350
end
351351

352352
@testset "Scheduler algorithms" begin
353-
# New function to hide from scheduler's function cost cache
354-
mynothing(args...) = nothing
355-
356-
# New non-singleton struct to hide from `approx_size`
357-
struct MyStruct
358-
x::Int
353+
@testset "Signature Calculation" begin
354+
@test Dagger.Sch.signature(+, [nothing=>1, nothing=>2]) isa Vector{DataType}
355+
@test Dagger.Sch.signature(+, [nothing=>1, nothing=>2]) == [typeof(+), Int, Int]
356+
@test Dagger.Sch.signature(+, [nothing=>1, :a=>2]) == [typeof(Core.kwcall), @NamedTuple{a::Int64}, typeof(+), Int]
357+
@test Dagger.Sch.signature(+, []) == [typeof(+)]
358+
@test Dagger.Sch.signature(+, [nothing=>1]) == [typeof(+), Int]
359+
360+
c = Dagger.tochunk(1.0)
361+
@test Dagger.Sch.signature(*, [nothing=>c, nothing=>3]) == [typeof(*), Float64, Int]
362+
t = Dagger.@spawn 1+2
363+
@test Dagger.Sch.signature(/, [nothing=>t, nothing=>c, nothing=>3]) == [typeof(/), Int, Float64, Int]
359364
end
360365

361-
state = Dagger.Sch.EAGER_STATE[]
362-
tproc1 = Dagger.ThreadProc(1, 1)
363-
tproc2 = Dagger.ThreadProc(first(workers()), 1)
364-
procs = [tproc1, tproc2]
365-
366-
pres1 = state.worker_time_pressure[1][tproc1]
367-
pres2 = state.worker_time_pressure[first(workers())][tproc2]
368-
tx_rate = state.transfer_rate[]
369-
370-
for (args, tx_size) in [
371-
([1, 2], 0),
372-
([Dagger.tochunk(1), 2], sizeof(Int)),
373-
([1, Dagger.tochunk(2)], sizeof(Int)),
374-
([Dagger.tochunk(1), Dagger.tochunk(2)], 2*sizeof(Int)),
375-
# TODO: Why does this work? Seems slow
376-
([Dagger.tochunk(MyStruct(1))], sizeof(MyStruct)),
377-
([Dagger.tochunk(MyStruct(1)), Dagger.tochunk(1)], sizeof(MyStruct)+sizeof(Int)),
378-
]
379-
for arg in args
380-
if arg isa Chunk
381-
aff = Dagger.affinity(arg)
382-
@test aff[1] == OSProc(1)
383-
@test aff[2] == MemPool.approx_size(MemPool.poolget(arg.handle))
384-
end
366+
@testset "Cost Estimation" begin
367+
# New function to hide from scheduler's function cost cache
368+
mynothing(args...) = nothing
369+
370+
# New non-singleton struct to hide from `approx_size`
371+
struct MyStruct
372+
x::Int
385373
end
386374

387-
cargs = map(arg->MemPool.poolget(arg.handle), filter(arg->isa(arg, Chunk), args))
388-
est_tx_size = Dagger.Sch.impute_sum(map(MemPool.approx_size, cargs))
389-
@test est_tx_size == tx_size
375+
state = Dagger.Sch.EAGER_STATE[]
376+
tproc1 = Dagger.ThreadProc(1, 1)
377+
tproc2 = Dagger.ThreadProc(first(workers()), 1)
378+
procs = [tproc1, tproc2]
379+
380+
pres1 = state.worker_time_pressure[1][tproc1]
381+
pres2 = state.worker_time_pressure[first(workers())][tproc2]
382+
tx_rate = state.transfer_rate[]
383+
384+
for (args, tx_size) in [
385+
([1, 2], 0),
386+
([Dagger.tochunk(1), 2], sizeof(Int)),
387+
([1, Dagger.tochunk(2)], sizeof(Int)),
388+
([Dagger.tochunk(1), Dagger.tochunk(2)], 2*sizeof(Int)),
389+
# TODO: Why does this work? Seems slow
390+
([Dagger.tochunk(MyStruct(1))], sizeof(MyStruct)),
391+
([Dagger.tochunk(MyStruct(1)), Dagger.tochunk(1)], sizeof(MyStruct)+sizeof(Int)),
392+
]
393+
for arg in args
394+
if arg isa Chunk
395+
aff = Dagger.affinity(arg)
396+
@test aff[1] == OSProc(1)
397+
@test aff[2] == MemPool.approx_size(MemPool.poolget(arg.handle))
398+
end
399+
end
400+
401+
cargs = map(arg->MemPool.poolget(arg.handle), filter(arg->isa(arg, Chunk), args))
402+
est_tx_size = Dagger.Sch.impute_sum(map(MemPool.approx_size, cargs))
403+
@test est_tx_size == tx_size
390404

391-
t = delayed(mynothing)(args...)
392-
inputs = Dagger.Sch.collect_task_inputs(state, t)
393-
sorted_procs, costs = Dagger.Sch.estimate_task_costs(state, procs, t, inputs)
405+
t = delayed(mynothing)(args...)
406+
inputs = Dagger.Sch.collect_task_inputs(state, t)
407+
sorted_procs, costs = Dagger.Sch.estimate_task_costs(state, procs, t, inputs)
394408

395-
@test tproc1 in sorted_procs
396-
@test tproc2 in sorted_procs
397-
if length(cargs) > 0
398-
@test sorted_procs[1] == tproc1
399-
@test sorted_procs[2] == tproc2
400-
end
409+
@test tproc1 in sorted_procs
410+
@test tproc2 in sorted_procs
411+
if length(cargs) > 0
412+
@test sorted_procs[1] == tproc1
413+
@test sorted_procs[2] == tproc2
414+
end
401415

402-
@test haskey(costs, tproc1)
403-
@test haskey(costs, tproc2)
404-
@test costs[tproc1] pres1 # All chunks are local
405-
@test costs[tproc2] (tx_size/tx_rate) + pres2 # All chunks are remote
416+
@test haskey(costs, tproc1)
417+
@test haskey(costs, tproc2)
418+
@test costs[tproc1] pres1 # All chunks are local
419+
@test costs[tproc2] (tx_size/tx_rate) + pres2 # All chunks are remote
420+
end
406421
end
407422
end
408423

0 commit comments

Comments
 (0)