Skip to content

Commit fcef8ac

Browse files
committed
Add initial distributed support
1 parent b169984 commit fcef8ac

File tree

4 files changed

+284
-95
lines changed

4 files changed

+284
-95
lines changed

Diff for: Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.18.4"
55
[deps]
66
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
77
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
8+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1011
MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"

Diff for: src/Dagger.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ include("utils/caching.jl")
4242
include("sch/Sch.jl"); using .Sch
4343

4444
# Data dependency task queue
45-
include("datadep.jl")
45+
include("datadeps.jl")
4646

4747
# Array computations
4848
include("array/darray.jl")

Diff for: src/datadep.jl

-94
This file was deleted.

Diff for: src/datadeps.jl

+282
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
using Graphs
2+
3+
export In, Out, InOut, spawn_datadeps
4+
5+
"""
    In(x)

Wrapper marking `x` as a read-only ("in") dependency of the task it is
passed to. The wrapped value is available as the field `x`.
"""
struct In{T}
    x::T
end

"""
    Out(x)

Wrapper marking `x` as a write-only ("out") dependency of the task it is
passed to. The wrapped value is available as the field `x`.
"""
struct Out{T}
    x::T
end

"""
    InOut(x)

Wrapper marking `x` as both read and written ("in-out") by the task it is
passed to. The wrapped value is available as the field `x`.
"""
struct InOut{T}
    x::T
end

"Union of the three data-dependency annotation wrappers."
const AnyInOut = Union{In,Out,InOut}
15+
16+
"""
    DataDepsTaskQueue(upper_queue; static::Bool=false)

Task queue that derives synchronization dependencies between tasks from the
way they access their arguments (see `In`, `Out`, `InOut`).

When `static` is `false`, the graph-related fields (`seen_tasks`, `g`,
`task_to_id`) are `nothing`; when it is `true` they start out empty and are
filled in as tasks are enqueued.
"""
struct DataDepsTaskQueue <: AbstractTaskQueue
    # Queue that tasks are forwarded to
    upper_queue::AbstractTaskQueue
    # For every unique argument object: the tasks launched so far that use
    # it, each paired with its (read, write) access flags
    deps::IdDict{Any, Vector{Pair{Tuple{Bool,Bool}, EagerThunk}}}
    # Analyze the DAG statically (true) or eagerly (false)
    static::Bool
    # static only: every task seen so far, in enqueue order
    seen_tasks::Union{Vector{Pair{EagerTaskSpec,EagerThunk}},Nothing}
    # static only: data-dependency graph over all tasks
    g::Union{SimpleDiGraph{Int},Nothing}
    # static only: vertex ID of each task in `g`
    task_to_id::Union{Dict{EagerThunk,Int},Nothing}

    function DataDepsTaskQueue(upper_queue; static::Bool=false)
        deps = IdDict{Any, Vector{Pair{Tuple{Bool,Bool}, EagerThunk}}}()
        # The static-analysis state only exists when it was asked for
        seen_tasks = static ? Pair{EagerTaskSpec,EagerThunk}[] : nothing
        g = static ? SimpleDiGraph() : nothing
        task_to_id = static ? Dict{EagerThunk,Int}() : nothing
        return new(upper_queue, deps, static, seen_tasks, g, task_to_id)
    end
end
44+
45+
# Make the task currently being enqueued wait on `prev_task`. In static mode,
# additionally record the edge `prev_task -> task_id` in the dependency graph.
function _datadeps_sync_with!(queue::DataDepsTaskQueue, syncdeps,
                              prev_task::EagerThunk, task_id::Int)
    push!(syncdeps, prev_task)
    if queue.static
        add_edge!(queue.g, queue.task_to_id[prev_task], task_id)
    end
    return nothing
end

"""
    enqueue!(queue::DataDepsTaskQueue, fullspec::Pair{EagerTaskSpec,EagerThunk})

Compute the data dependencies of a single task and enqueue it.

Each argument is unwrapped from `In`/`Out`/`InOut` (an unwrapped argument
counts as read-only) and the task's `syncdeps` option is extended with every
earlier task that conflicts on the same object: readers wait on earlier
writers, writers wait on earlier readers and writers.

In static mode the task is recorded in `queue.seen_tasks` and in the
dependency graph; otherwise it is forwarded to `queue.upper_queue`.
"""
function enqueue!(queue::DataDepsTaskQueue, fullspec::Pair{EagerTaskSpec,EagerThunk})
    spec, task = fullspec

    # In static mode, every task becomes a vertex of the dependency graph
    task_id = 0
    if queue.static
        add_vertex!(queue.g)
        task_id = nv(queue.g)
        queue.task_to_id[task] = task_id
    end

    opts = spec.options
    syncdeps = get(Set{Any}, opts, :syncdeps)

    # Accesses made by this task; registered only after all arguments have
    # been processed
    pending = Vector{Pair{Any, Tuple{Bool,Bool}}}()

    for (idx, (pos, arg)) in enumerate(spec.args)
        # Plain (unwrapped) arguments are treated as read-only
        reads, writes = true, false
        if arg isa In
            arg = arg.x
        elseif arg isa Out
            reads, writes = false, true
            arg = arg.x
        elseif arg isa InOut
            writes = true
            arg = arg.x
        end
        spec.args[idx] = pos => arg

        push!(pending, arg => (reads, writes))

        # Nothing to synchronize with if no earlier task touched this object
        haskey(queue.deps, arg) || continue
        history = queue.deps[arg]::Vector{Pair{Tuple{Bool,Bool}, EagerThunk}}

        if reads
            # A read must wait for every earlier write
            for ((_, prev_write), prev_task) in history
                prev_write && _datadeps_sync_with!(queue, syncdeps, prev_task, task_id)
            end
        end
        if writes
            # A write must wait for every earlier read or write
            for ((prev_read, prev_write), prev_task) in history
                (prev_read || prev_write) &&
                    _datadeps_sync_with!(queue, syncdeps, prev_task, task_id)
            end
        end
    end

    for (arg, access) in pending
        history = get!(() -> Vector{Pair{Tuple{Bool,Bool}, EagerThunk}}(), queue.deps, arg)
        push!(history, access => task)
    end

    spec.options = merge(opts, (;syncdeps,))

    if queue.static
        push!(queue.seen_tasks, fullspec)
    else
        enqueue!(queue.upper_queue, fullspec)
    end
end
126+
"""
    enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{EagerTaskSpec,EagerThunk}})

Batch enqueue for `DataDepsTaskQueue`. Not implemented yet: this method
always throws an `ErrorException` with the message "Not yet implemented".
"""
function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{EagerTaskSpec,EagerThunk}})
    # FIXME: Don't register as previous tasks until very end
    # NOTE(review): presumably a plain loop over the single-task `enqueue!`
    # is not sufficient because of the FIXME above - confirm before
    # implementing. The unreachable loop that used to follow `error` was
    # removed.
    error("Not yet implemented")
end
133+
134+
"""
    distribute_tasks!(queue::DataDepsTaskQueue)

Launch every task recorded in `queue.seen_tasks`, assigning tasks to the
processors of all workers in round-robin order.

For every `Array` tracked in `queue.deps`, a chunk is created on each other
worker via `Dagger.tochunk`. A task placed on a worker other than the calling
one gets its `Array` arguments replaced by that worker's chunk, preceded by a
`copyto!` task into the chunk and followed by a `copyto!` task back into the
caller's array. Ordering between tasks and copies on the same array is
enforced through `syncdeps`.

The calling worker (`myid()`) is treated as the home of all data, so the
function behaves the same regardless of which worker invokes it.
"""
function distribute_tasks!(queue::DataDepsTaskQueue)
    # "Distributes" the graph by making cuts
    # TODO: We currently assume:
    # - All data is local to this worker
    # - All data is the same size
    # - All tasks take the same amount of time to execute
    # - Tasks executing on other workers will have data moved for them
    # - All data will be updated locally at the end of the computation
    # FIXME: Don't do round-robin
    # FIXME: Skip this if only one proc
    all_procs = Processor[]
    for w in procs()
        append!(all_procs, get_processors(OSProc(w)))
    end

    # All data is assumed to live on the calling worker. This must agree with
    # the `w == data_worker` branch below, which stores the original array
    # (not a chunk) for this worker.
    # TODO: Track per-argument data locality instead, and allow it to change
    data_worker = myid()

    # Make a copy of each piece of data on each worker
    remote_args = Dict{Int,IdDict{Any,Any}}(w=>IdDict{Any,Any}() for w in procs())
    # Latest task (user task or copy) that touched each argument
    # FIXME: Owner can repeat (same arg twice to one task)
    args_owner = IdDict{Any,Any}(arg=>nothing for arg in keys(queue.deps))
    for w in procs()
        for data in keys(queue.deps)
            data isa Array || continue
            if w == data_worker
                remote_args[w][data] = data
            else
                # TODO: Can't use @mutable with custom Chunk scope
                remote_args[w][data] = remotecall_fetch(Dagger.tochunk, w, data)
            end
        end
    end

    # The queue that tasks are submitted to does not change while distributing
    task_queue = get_options(:task_queue)

    # Round-robin assign tasks to processors
    proc_idx = 1
    for (spec, task) in queue.seen_tasks
        our_proc = all_procs[proc_idx]
        our_proc_worker = root_worker_id(our_proc)

        # Spawn copies before and after user's task, as necessary
        @dagdebug nothing :spawn_datadeps "Scheduling $(spec.f)"
        task_syncdeps = Set()
        # Snapshot of the original arguments; `spec.args` is rewritten below
        task_args = copy(spec.args)

        # Copy args from local to remote
        for (idx, (pos, arg)) in enumerate(task_args)
            arg isa Array || continue
            if our_proc_worker != data_worker
                # Add copy-to operation (depends on latest owner of arg)
                @dagdebug nothing :spawn_datadeps "Enqueueing copy-to: $data_worker => $our_proc_worker"
                arg_local = remote_args[data_worker][arg]
                @assert arg_local === spec.args[idx][2]
                arg_remote = remote_args[our_proc_worker][arg]
                copy_to_scope = scope(worker=our_proc_worker)
                copy_to_syncdeps = Set()
                if (owner = args_owner[arg]) !== nothing
                    @dagdebug nothing :spawn_datadeps "(copy-to arg) Depending on previous owner"
                    push!(copy_to_syncdeps, owner)
                end
                copy_to = Dagger.@spawn scope=copy_to_scope syncdeps=copy_to_syncdeps copyto!(arg_remote, arg_local)
                push!(task_syncdeps, copy_to)
                # The user's task operates on the remote copy
                spec.args[idx] = pos => arg_remote
            else
                if (owner = args_owner[arg]) !== nothing
                    @dagdebug nothing :spawn_datadeps "(local arg) Depending on previous owner"
                    push!(task_syncdeps, owner)
                end
            end
        end

        # Launch user's task
        syncdeps = get(Set, spec.options, :syncdeps)
        for other_task in task_syncdeps
            push!(syncdeps, other_task)
        end
        task_scope = scope(worker=our_proc_worker)
        spec.options = merge(spec.options, (;syncdeps, scope=task_scope))
        enqueue!(task_queue, spec=>task)
        for (_, arg) in task_args
            arg isa Array || continue
            args_owner[arg] = task
        end

        # Copy args from remote to local
        # TODO: Don't always copy to-and-from
        for (_, arg) in task_args
            arg isa Array || continue
            if our_proc_worker != data_worker
                # Add copy-from operation
                @dagdebug nothing :spawn_datadeps "Enqueueing copy-from: $our_proc_worker => $data_worker"
                arg_local = remote_args[data_worker][arg]
                arg_remote = remote_args[our_proc_worker][arg]
                copy_from_scope = scope(worker=data_worker)
                copy_from_syncdeps = Set([task])
                copy_from = Dagger.@spawn scope=copy_from_scope syncdeps=copy_from_syncdeps copyto!(arg_local, arg_remote)

                # Set copy-from as latest owner of arg
                args_owner[arg] = copy_from
            end
        end
        proc_idx = mod1(proc_idx+1, length(all_procs))
    end
    return nothing
end
250+
251+
"""
    spawn_datadeps(f::Base.Callable; static::Bool=false)

Run `f` with a `DataDepsTaskQueue` installed as the `:task_queue` option and
return the result of `f`. The new queue sits on top of the currently
configured task queue, or a fresh `EagerTaskQueue` if none is set.

With `static=true`, the tasks collected while `f` ran are handed to
`distribute_tasks!` after `f` returns.
"""
function spawn_datadeps(f::Base.Callable; static::Bool=false)
    upper_queue = get_options(:task_queue, EagerTaskQueue())
    queue = DataDepsTaskQueue(upper_queue; static=static)
    result = with_options(f; task_queue=queue)
    # Static queues only record tasks; launch them now
    queue.static && distribute_tasks!(queue)
    return result
end
259+
260+
# FIXME: Move this elsewhere
261+
"""
    WaitAllQueue(upper_queue, tasks)

Task queue that records every task passing through it in `tasks` and forwards
it to `upper_queue`.
"""
struct WaitAllQueue <: AbstractTaskQueue
    # Queue that tasks are forwarded to
    upper_queue::AbstractTaskQueue
    # Every task enqueued through this queue, in enqueue order
    tasks::Vector{EagerThunk}
end
265+
# Record the task of a single spec, then pass the spec on to the upper queue.
function enqueue!(queue::WaitAllQueue, spec::Pair{EagerTaskSpec,EagerThunk})
    thunk = spec.second
    push!(queue.tasks, thunk)
    return enqueue!(queue.upper_queue, spec)
end
269+
# Record the task of every spec in the batch, then pass the whole batch on to
# the upper queue in a single call.
function enqueue!(queue::WaitAllQueue, specs::Vector{Pair{EagerTaskSpec,EagerThunk}})
    append!(queue.tasks, (thunk for (_, thunk) in specs))
    return enqueue!(queue.upper_queue, specs)
end
275+
"""
    wait_all(f)

Run `f` with a `WaitAllQueue` installed as the `:task_queue` option, `fetch`
every task that was enqueued through it, and return the result of `f`.
"""
function wait_all(f)
    upper_queue = get_options(:task_queue, EagerTaskQueue())
    queue = WaitAllQueue(upper_queue, EagerThunk[])
    result = with_options(f; task_queue=queue)
    # Fetch each recorded task, in enqueue order
    foreach(fetch, queue.tasks)
    return result
end

0 commit comments

Comments
 (0)