Skip to content

Commit 2110d62

Browse files
authored
Merge pull request #464 from JuliaParallel/jps/chained-dtors
Eagerly free Thunk cached results when unneeded
2 parents 0c123a8 + 3781e53 commit 2110d62

File tree

9 files changed

+131
-37
lines changed

9 files changed

+131
-37
lines changed

Diff for: .buildkite/pipeline.yml

+10
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ steps:
3434
julia_args: "--threads=1"
3535
- JuliaCI/julia-coverage#v1:
3636
codecov: true
37+
- label: Julia 1.10
38+
timeout_in_minutes: 60
39+
<<: *test
40+
plugins:
41+
- JuliaCI/julia#v1:
42+
version: "1.10"
43+
- JuliaCI/julia-test#v1:
44+
julia_args: "--threads=1"
45+
- JuliaCI/julia-coverage#v1:
46+
codecov: true
3747
- label: Julia nightly
3848
timeout_in_minutes: 60
3949
<<: *test

Diff for: Manifest.toml

+16-16
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.8.5"
44
manifest_format = "2.0"
5-
project_hash = "5333a6c200b6e6add81c46547527f66ddc0dc16c"
5+
project_hash = "8da7911e4788068aaea8c0ef8589d674bce0fb39"
66

77
[[deps.Artifacts]]
88
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -12,9 +12,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1212

1313
[[deps.ChainRulesCore]]
1414
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
15-
git-tree-sha1 = "b66b8f8e3db5d7835fb8cbe2589ffd1cd456e491"
15+
git-tree-sha1 = "2118cb2765f8197b08e5958cdd17c165427425ee"
1616
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
17-
version = "1.17.0"
17+
version = "1.19.0"
1818

1919
[[deps.ChangesOfVariables]]
2020
deps = ["InverseFunctions", "LinearAlgebra", "Test"]
@@ -24,9 +24,9 @@ version = "0.1.8"
2424

2525
[[deps.Compat]]
2626
deps = ["Dates", "LinearAlgebra", "UUIDs"]
27-
git-tree-sha1 = "8a62af3e248a8c4bad6b32cbbe663ae02275e32c"
27+
git-tree-sha1 = "886826d76ea9e72b35fcd000e535588f7b60f21d"
2828
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
29-
version = "4.10.0"
29+
version = "4.10.1"
3030

3131
[[deps.CompilerSupportLibraries_jll]]
3232
deps = ["Artifacts", "Libdl"]
@@ -100,19 +100,19 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
100100

101101
[[deps.MacroTools]]
102102
deps = ["Markdown", "Random"]
103-
git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48"
103+
git-tree-sha1 = "b211c553c199c111d998ecdaf7623d1b89b69f93"
104104
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
105-
version = "0.5.11"
105+
version = "0.5.12"
106106

107107
[[deps.Markdown]]
108108
deps = ["Base64"]
109109
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
110110

111111
[[deps.MemPool]]
112-
deps = ["DataStructures", "Distributed", "Mmap", "Random", "Serialization", "Sockets"]
113-
git-tree-sha1 = "b9c1a032c3c1310a857c061ce487c632eaa1faa4"
112+
deps = ["DataStructures", "Distributed", "Mmap", "Random", "ScopedValues", "Serialization", "Sockets"]
113+
git-tree-sha1 = "60dd4ac427d39e0b3f15b193845324523ee71c03"
114114
uuid = "f9f48841-c794-520a-933b-121f7ba6ed94"
115-
version = "0.4.4"
115+
version = "0.4.6"
116116

117117
[[deps.Missings]]
118118
deps = ["DataAPI"]
@@ -133,9 +133,9 @@ uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
133133
version = "0.3.20+0"
134134

135135
[[deps.OrderedCollections]]
136-
git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3"
136+
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
137137
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
138-
version = "1.6.2"
138+
version = "1.6.3"
139139

140140
[[deps.PrecompileTools]]
141141
deps = ["Preferences"]
@@ -173,9 +173,9 @@ version = "0.7.0"
173173

174174
[[deps.ScopedValues]]
175175
deps = ["HashArrayMappedTries", "Logging"]
176-
git-tree-sha1 = "e3b5e4ccb1702db2ae9ac2a660d4b6b2a8595742"
176+
git-tree-sha1 = "c27d546a4749c81f70d1fabd604da6aa5054e3d2"
177177
uuid = "7e506255-f358-4e82-b7e4-beb19740aa63"
178-
version = "1.1.0"
178+
version = "1.2.0"
179179

180180
[[deps.Serialization]]
181181
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -189,9 +189,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
189189

190190
[[deps.SortingAlgorithms]]
191191
deps = ["DataStructures"]
192-
git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee"
192+
git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085"
193193
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
194-
version = "1.1.1"
194+
version = "1.2.1"
195195

196196
[[deps.SparseArrays]]
197197
deps = ["LinearAlgebra", "Random"]

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2424
[compat]
2525
DataStructures = "0.18"
2626
MacroTools = "0.5"
27-
MemPool = "0.4.5"
27+
MemPool = "0.4.6"
2828
PrecompileTools = "1.2"
2929
Requires = "1"
3030
ScopedValues = "1.1"

Diff for: src/precompile.jl

+30-8
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,47 @@
33
add_processor_callback!("__cpu_thread_1__") do
44
ThreadProc(1, 1)
55
end
6-
t1 = @spawn 1+1
6+
# FIXME: t1 = @spawn 1+1
7+
t1 = spawn(+, 1, 1)
8+
fetch(t1)
79
t2 = spawn(+, 1, t1)
810
fetch(t2)
9-
spawn() do
10-
Sch.halt!(sch_handle())
11+
12+
# Clean up refs
13+
t1 = nothing; t2 = nothing
14+
state = Sch.EAGER_STATE[]
15+
for i in 1:5
16+
length(state.thunk_dict) == 1 && break
17+
GC.gc()
18+
yield()
1119
end
20+
@assert length(state.thunk_dict) == 1
21+
22+
# Halt scheduler
23+
notify(state.halt)
24+
put!(state.chan, (1, nothing, nothing, (Sch.SchedulerHaltedException(), nothing)))
25+
state = nothing
26+
27+
# Wait for halt
1228
while Sch.EAGER_INIT[]
1329
sleep(0.1)
1430
end
31+
32+
# Final clean-up
1533
Sch.EAGER_CONTEXT[] = nothing
16-
GC.gc()
17-
yield()
34+
GC.gc(); yield()
1835
lock(Sch.ERRORMONITOR_TRACKED) do tracked
19-
if all(t->istaskdone(t) || istaskfailed(t), tracked)
36+
if all(t->istaskdone(t) || istaskfailed(t), map(last, tracked))
2037
empty!(tracked)
2138
return
2239
end
23-
for t in tracked
24-
Base.throwto(t, InterruptException())
40+
for (name, t) in tracked
41+
@warn "Waiting on $name"
42+
if t.state == :runnable
43+
Base.throwto(t, InterruptException())
44+
else
45+
wait(t)
46+
end
2547
end
2648
end
2749
MemPool.exit_hook()

Diff for: src/sch/Sch.jl

+41-3
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ Fields:
7272
- `lock::ReentrantLock` - Lock around operations which modify the state
7373
- `futures::Dict{Thunk, Vector{ThunkFuture}}` - Futures registered for waiting on the result of a thunk.
7474
- `errored::WeakKeyDict{Thunk,Bool}` - Indicates if a thunk's result is an error.
75+
- `thunks_to_delete::Set{Thunk}` - The list of `Thunk`s ready to be deleted upon completion.
7576
- `chan::RemoteChannel{Channel{Any}}` - Channel for receiving completed thunks.
7677
"""
7778
struct ComputeState
@@ -98,6 +99,7 @@ struct ComputeState
9899
lock::ReentrantLock
99100
futures::Dict{Thunk, Vector{ThunkFuture}}
100101
errored::WeakKeyDict{Thunk,Bool}
102+
thunks_to_delete::Set{Thunk}
101103
chan::RemoteChannel{Channel{Any}}
102104
end
103105

@@ -127,6 +129,7 @@ function start_state(deps::Dict, node_order, chan)
127129
ReentrantLock(),
128130
Dict{Thunk, Vector{ThunkFuture}}(),
129131
WeakKeyDict{Thunk,Bool}(),
132+
Set{Thunk}(),
130133
chan)
131134

132135
for k in sort(collect(keys(deps)), by=node_order)
@@ -366,7 +369,7 @@ function init_proc(state, p, log_sink)
366369
end
367370
end
368371
end
369-
errormonitor_tracked(t)
372+
errormonitor_tracked("worker monitor $wid", t)
370373
WORKER_MONITOR_TASKS[wid] = t
371374
WORKER_MONITOR_CHANS[wid] = Dict{UInt64,RemoteChannel}()
372375
end
@@ -590,6 +593,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options)
590593
timespan_start(ctx, :finish, thunk_id, (;thunk_id))
591594
finish_task!(ctx, state, node, thunk_failed)
592595
timespan_finish(ctx, :finish, thunk_id, (;thunk_id))
596+
597+
delete_unused_tasks!(state)
593598
end
594599

595600
safepoint(state)
@@ -931,6 +936,39 @@ function finish_task!(ctx, state, node, thunk_failed)
931936
evict_all_chunks!(ctx, to_evict)
932937
end
933938

939+
function delete_unused_tasks!(state)
940+
to_delete = Thunk[]
941+
for thunk in state.thunks_to_delete
942+
if task_unused(state, thunk)
943+
# Finished and nobody waiting on us, we can be deleted
944+
push!(to_delete, thunk)
945+
end
946+
end
947+
for thunk in to_delete
948+
# Delete all cached data
949+
task_delete!(state, thunk)
950+
951+
pop!(state.thunks_to_delete, thunk)
952+
end
953+
end
954+
function delete_unused_task!(state, thunk)
955+
if task_unused(state, thunk)
956+
# Will not be accessed further, delete all cached data
957+
task_delete!(state, thunk)
958+
return true
959+
else
960+
return false
961+
end
962+
end
963+
task_unused(state, thunk) =
964+
haskey(state.cache, thunk) && !haskey(state.waiting_data, thunk)
965+
function task_delete!(state, thunk)
966+
delete!(state.cache, thunk)
967+
delete!(state.errored, thunk)
968+
delete!(state.valid, thunk)
969+
delete!(state.thunk_dict, thunk.id)
970+
end
971+
934972
function evict_all_chunks!(ctx, to_evict)
935973
if !isempty(to_evict)
936974
@sync for w in map(p->p.pid, procs_to_use(ctx))
@@ -1290,7 +1328,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
12901328
else
12911329
t.sticky = false
12921330
end
1293-
tasks[thunk_id] = errormonitor_tracked(schedule(t))
1331+
tasks[thunk_id] = errormonitor_tracked("thunk $thunk_id", schedule(t))
12941332
proc_occupancy[] += task_occupancy
12951333
time_pressure[] += time_util
12961334
end
@@ -1302,7 +1340,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13021340
else
13031341
proc_run_task.sticky = false
13041342
end
1305-
return errormonitor_tracked(schedule(proc_run_task))
1343+
return errormonitor_tracked("processor $to_proc", schedule(proc_run_task))
13061344
end
13071345

13081346
"""

Diff for: src/sch/dynamic.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ function dynamic_listener!(ctx, state, wid)
8787
end
8888
end
8989
end
90-
errormonitor_tracked(listener_task)
91-
errormonitor_tracked(@async begin
90+
errormonitor_tracked("dynamic_listener! $wid", listener_task)
91+
errormonitor_tracked("dynamic_listener! (halt+throw) $wid", @async begin
9292
wait(state.halt)
9393
# TODO: Not sure why we need the @async here, but otherwise we
9494
# don't stop all the listener tasks

Diff for: src/sch/eager.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ function init_eager()
2020
return
2121
end
2222
ctx = eager_context()
23-
errormonitor_tracked(Threads.@spawn try
23+
errormonitor_tracked("eager compute()", Threads.@spawn try
2424
sopts = SchedulerOptions(;allow_errors=true)
2525
opts = Dagger.Options((;scope=Dagger.ExactScope(Dagger.ThreadProc(1, 1)),
2626
occupancy=Dict(Dagger.ThreadProc=>0)))
@@ -101,7 +101,7 @@ function thunk_yield(f)
101101
end
102102

103103
eager_cleanup(t::Dagger.EagerThunkFinalizer) =
104-
errormonitor_tracked(Threads.@spawn eager_cleanup(EAGER_STATE[], t.uid))
104+
errormonitor_tracked("eager_cleanup $(t.uid)", Threads.@spawn eager_cleanup(EAGER_STATE[], t.uid))
105105
function eager_cleanup(state, uid)
106106
tid = nothing
107107
lock(EAGER_ID_MAP) do id_map

Diff for: src/sch/util.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"Like `errormonitor`, but tracks how many outstanding tasks are running."
2-
function errormonitor_tracked(t::Task)
2+
function errormonitor_tracked(name::String, t::Task)
33
errormonitor(t)
44
@safe_lock_spin1 ERRORMONITOR_TRACKED tracked begin
5-
push!(tracked, t)
5+
push!(tracked, name => t)
66
end
77
errormonitor(Threads.@spawn begin
88
try
99
wait(t)
1010
finally
1111
lock(ERRORMONITOR_TRACKED) do tracked
12-
idx = findfirst(o->o===t, tracked)
12+
idx = findfirst(o->o[2]===t, tracked)
1313
# N.B. This may be nothing if precompile emptied these
1414
if idx !== nothing
1515
deleteat!(tracked, idx)
@@ -18,7 +18,7 @@ function errormonitor_tracked(t::Task)
1818
end
1919
end)
2020
end
21-
const ERRORMONITOR_TRACKED = LockedObject(Task[])
21+
const ERRORMONITOR_TRACKED = LockedObject(Pair{String,Task}[])
2222

2323
"""
2424
unwrap_nested_exception(err::Exception) -> Bool

Diff for: src/submission.jl

+25-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ function eager_submit_internal!(ctx, state, task, tid, payload; uid_to_tid=Dict{
8989
thunk = Thunk(f, args...; options...)
9090

9191
# Create a `DRef` to `thunk` so that the caller can preserve it
92-
thunk_ref = poolset(thunk; size=64, device=MemPool.CPURAMDevice())
92+
thunk_ref = poolset(thunk; size=64, device=MemPool.CPURAMDevice(),
93+
destructor=UnrefThunkByUser(thunk))
9394
thunk_id = Sch.ThunkID(thunk.id, thunk_ref)
9495

9596
# Attach `thunk` within the scheduler
@@ -122,6 +123,29 @@ function eager_submit_internal!(ctx, state, task, tid, payload; uid_to_tid=Dict{
122123
return thunk_id
123124
end
124125
end
126+
struct UnrefThunkByUser
127+
thunk::Thunk
128+
end
129+
function (unref::UnrefThunkByUser)()
130+
Sch.errormonitor_tracked("unref thunk $(unref.thunk.id)", Threads.@spawn begin
131+
# This thunk is no longer referenced by the user, mark it as ready to be
132+
# cleaned up as eagerly as possible (or do so now)
133+
thunk = unref.thunk
134+
state = Sch.EAGER_STATE[]
135+
if state === nothing
136+
return
137+
end
138+
139+
@lock state.lock begin
140+
if !Sch.delete_unused_task!(state, thunk)
141+
# Register for deletion upon thunk completion
142+
push!(state.thunks_to_delete, thunk)
143+
end
144+
# TODO: On success, walk down to children, as a fast-path
145+
end
146+
end)
147+
end
148+
125149

126150
# Local -> Remote
127151
function eager_submit!(ntasks, uid, future, finalizer_ref, f, args, options)

0 commit comments

Comments
 (0)