Skip to content

Commit 9f0ac3f

Browse files
authored
Merge pull request #568 from JuliaParallel/streaming-migration
Initial support for robustly migrating streaming tasks
2 parents 4013384 + 0a86a70 commit 9f0ac3f

17 files changed

+1152
-423
lines changed

Diff for: src/Dagger.jl

+20-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ else
2525
end
2626
import TaskLocalValues: TaskLocalValue
2727

28+
import TaskLocalValues: TaskLocalValue
29+
2830
if !isdefined(Base, :get_extension)
2931
import Requires: @require
3032
end
@@ -47,16 +49,16 @@ include("processor.jl")
4749
include("threadproc.jl")
4850
include("context.jl")
4951
include("utils/processors.jl")
52+
include("dtask.jl")
53+
include("cancellation.jl")
5054
include("task-tls.jl")
5155
include("scopes.jl")
5256
include("utils/scopes.jl")
53-
include("dtask.jl")
5457
include("queue.jl")
5558
include("thunk.jl")
5659
include("submission.jl")
5760
include("chunks.jl")
5861
include("memory-spaces.jl")
59-
include("cancellation.jl")
6062

6163
# Task scheduling
6264
include("compute.jl")
@@ -69,9 +71,9 @@ include("sch/Sch.jl"); using .Sch
6971
include("datadeps.jl")
7072

7173
# Streaming
72-
include("stream-buffers.jl")
73-
include("stream-fetchers.jl")
7474
include("stream.jl")
75+
include("stream-buffers.jl")
76+
include("stream-transfer.jl")
7577

7678
# Array computations
7779
include("array/darray.jl")
@@ -152,6 +154,20 @@ function __init__()
152154
ThreadProc(myid(), tid)
153155
end
154156
end
157+
158+
# Set up @dagdebug categories, if specified
159+
try
160+
if haskey(ENV, "JULIA_DAGGER_DEBUG")
161+
empty!(DAGDEBUG_CATEGORIES)
162+
for category in split(ENV["JULIA_DAGGER_DEBUG"], ",")
163+
if category != ""
164+
push!(DAGDEBUG_CATEGORIES, Symbol(category))
165+
end
166+
end
167+
end
168+
catch err
169+
@warn "Error parsing JULIA_DAGGER_DEBUG" exception=err
170+
end
155171
end
156172

157173
end # module

Diff for: src/array/indexing.jl

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import TaskLocalValues: TaskLocalValue
2-
31
### getindex
42

53
struct GetIndex{T,N} <: ArrayOp{T,N}

Diff for: src/cancellation.jl

+52-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,51 @@
1+
# DTask-level cancellation
2+
3+
mutable struct CancelToken
4+
@atomic cancelled::Bool
5+
@atomic graceful::Bool
6+
event::Base.Event
7+
end
8+
CancelToken() = CancelToken(false, false, Base.Event())
9+
function cancel!(token::CancelToken; graceful::Bool=true)
10+
if !graceful
11+
@atomic token.graceful = false
12+
end
13+
@atomic token.cancelled = true
14+
notify(token.event)
15+
return
16+
end
17+
function is_cancelled(token::CancelToken; must_force::Bool=false)
18+
if token.cancelled[]
19+
if must_force && token.graceful[]
20+
# If we're only responding to forced cancellation, ignore graceful cancellations
21+
return false
22+
end
23+
return true
24+
end
25+
return false
26+
end
27+
Base.wait(token::CancelToken) = wait(token.event)
28+
# TODO: Enable this for safety
29+
#Serialization.serialize(io::AbstractSerializer, ::CancelToken) =
30+
# throw(ConcurrencyViolationError("Cannot serialize a CancelToken"))
31+
32+
const DTASK_CANCEL_TOKEN = TaskLocalValue{Union{CancelToken,Nothing}}(()->nothing)
33+
34+
function clone_cancel_token_remote(orig_token::CancelToken, wid::Integer)
35+
remote_token = remotecall_fetch(wid) do
36+
return poolset(CancelToken())
37+
end
38+
errormonitor_tracked("remote cancel_token communicator", Threads.@spawn begin
39+
wait(orig_token)
40+
@dagdebug nothing :cancel "Cancelling remote token on worker $wid"
41+
MemPool.access_ref(remote_token) do remote_token
42+
cancel!(remote_token)
43+
end
44+
end)
45+
end
46+
47+
# Global-level cancellation
48+
149
"""
250
cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false)
351
@@ -48,7 +96,7 @@ function _cancel!(state, tid, force, halt_sch)
4896
for task in state.ready
4997
tid !== nothing && task.id != tid && continue
5098
@dagdebug tid :cancel "Cancelling ready task"
51-
state.cache[task] = InterruptException()
99+
state.cache[task] = DTaskFailedException(task, task, InterruptException())
52100
state.errored[task] = true
53101
Sch.set_failed!(state, task)
54102
end
@@ -58,7 +106,7 @@ function _cancel!(state, tid, force, halt_sch)
58106
for task in keys(state.waiting)
59107
tid !== nothing && task.id != tid && continue
60108
@dagdebug tid :cancel "Cancelling waiting task"
61-
state.cache[task] = InterruptException()
109+
state.cache[task] = DTaskFailedException(task, task, InterruptException())
62110
state.errored[task] = true
63111
Sch.set_failed!(state, task)
64112
end
@@ -80,11 +128,11 @@ function _cancel!(state, tid, force, halt_sch)
80128
Tf === typeof(Sch.eager_thunk) && continue
81129
istaskdone(task) && continue
82130
any_cancelled = true
83-
@dagdebug tid :cancel "Cancelling running task ($Tf)"
84131
if force
85132
@dagdebug tid :cancel "Interrupting running task ($Tf)"
86133
Threads.@spawn Base.throwto(task, InterruptException())
87134
else
135+
@dagdebug tid :cancel "Cancelling running task ($Tf)"
88136
# Tell the processor to just drop this task
89137
task_occupancy = task_spec[4]
90138
time_util = task_spec[2]
@@ -93,6 +141,7 @@ function _cancel!(state, tid, force, halt_sch)
93141
push!(istate.cancelled, tid)
94142
to_proc = istate.proc
95143
put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing)))
144+
cancel!(istate.cancel_tokens[tid]; graceful=false)
96145
end
97146
end
98147
end

Diff for: src/options.jl

+6
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ function with_options(f, options::NamedTuple)
2020
end
2121
with_options(f; options...) = with_options(f, NamedTuple(options))
2222

23+
function _without_options(f)
24+
with(options_context => NamedTuple()) do
25+
f()
26+
end
27+
end
28+
2329
"""
2430
get_options(key::Symbol, default) -> Any
2531
get_options(key::Symbol) -> Any

Diff for: src/sch/Sch.jl

+15-1
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,7 @@ struct ProcessorInternalState
11701170
proc_occupancy::Base.RefValue{UInt32}
11711171
time_pressure::Base.RefValue{UInt64}
11721172
cancelled::Set{Int}
1173+
cancel_tokens::Dict{Int,Dagger.CancelToken}
11731174
done::Base.RefValue{Bool}
11741175
end
11751176
struct ProcessorState
@@ -1189,7 +1190,7 @@ function proc_states(f::Base.Callable, uid::UInt64)
11891190
end
11901191
end
11911192
proc_states(f::Base.Callable) =
1192-
proc_states(f, task_local_storage(:_dagger_sch_uid)::UInt64)
1193+
proc_states(f, Dagger.get_tls().sch_uid)
11931194

11941195
task_tid_for_processor(::Processor) = nothing
11951196
task_tid_for_processor(proc::Dagger.ThreadProc) = proc.tid
@@ -1318,7 +1319,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13181319

13191320
# Execute the task and return its result
13201321
t = @task begin
1322+
# Set up cancellation
1323+
cancel_token = Dagger.CancelToken()
1324+
Dagger.DTASK_CANCEL_TOKEN[] = cancel_token
1325+
lock(istate.queue) do _
1326+
istate.cancel_tokens[thunk_id] = cancel_token
1327+
end
13211328
was_cancelled = false
1329+
13221330
result = try
13231331
do_task(to_proc, task)
13241332
catch err
@@ -1335,6 +1343,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13351343
# Task was cancelled, so occupancy and pressure are
13361344
# already reduced
13371345
pop!(istate.cancelled, thunk_id)
1346+
delete!(istate.cancel_tokens, thunk_id)
13381347
was_cancelled = true
13391348
end
13401349
end
@@ -1352,6 +1361,9 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13521361
else
13531362
rethrow(err)
13541363
end
1364+
finally
1365+
# Ensure that any spawned tasks get cleaned up
1366+
Dagger.cancel!(cancel_token)
13551367
end
13561368
end
13571369
lock(istate.queue) do _
@@ -1401,6 +1413,7 @@ function do_tasks(to_proc, return_queue, tasks)
14011413
Dict{Int,Vector{Any}}(),
14021414
Ref(UInt32(0)), Ref(UInt64(0)),
14031415
Set{Int}(),
1416+
Dict{Int,Dagger.CancelToken}(),
14041417
Ref(false))
14051418
runner = start_processor_runner!(istate, uid, return_queue)
14061419
@static if VERSION < v"1.9"
@@ -1640,6 +1653,7 @@ function do_task(to_proc, task_desc)
16401653
sch_handle,
16411654
processor=to_proc,
16421655
task_spec=task_desc,
1656+
cancel_token=Dagger.DTASK_CANCEL_TOKEN[],
16431657
))
16441658

16451659
res = Dagger.with_options(propagated) do

Diff for: src/sch/dynamic.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct SchedulerHandle
1717
end
1818

1919
"Gets the scheduler handle for the currently-executing thunk."
20-
sch_handle() = task_local_storage(:_dagger_sch_handle)::SchedulerHandle
20+
sch_handle() = Dagger.get_tls().sch_handle::SchedulerHandle
2121

2222
"Thrown when the scheduler halts before finishing processing the DAG."
2323
struct SchedulerHaltedException <: Exception end

Diff for: src/sch/eager.jl

+3
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,6 @@ function _find_thunk(e::Dagger.DTask)
141141
unwrap_weak_checked(EAGER_STATE[].thunk_dict[tid])
142142
end
143143
end
144+
Dagger.task_id(t::Dagger.DTask) = lock(EAGER_ID_MAP) do id_map
145+
id_map[t.uid]
146+
end

Diff for: src/sch/util.jl

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ unwrap_nested_exception(err::CapturedException) =
2929
unwrap_nested_exception(err.ex)
3030
unwrap_nested_exception(err::RemoteException) =
3131
unwrap_nested_exception(err.captured)
32+
unwrap_nested_exception(err::DTaskFailedException) =
33+
unwrap_nested_exception(err.ex)
3234
unwrap_nested_exception(err) = err
3335

3436
"Gets a `NamedTuple` of options propagated by `thunk`."

0 commit comments

Comments
 (0)