1
+ using Graphs
2
+
3
export In, Out, InOut, spawn_datadeps

"""
    In(x)

Annotates the argument `x` as read-only for dependency tracking: tasks
wrapping an argument in `In` will sync against prior writers of `x`, but
not against prior readers.
"""
struct In{T}
    x::T
end

"""
    Out(x)

Annotates the argument `x` as write-only for dependency tracking: tasks
wrapping an argument in `Out` will sync against all prior readers and
writers of `x`.
"""
struct Out{T}
    x::T
end

"""
    InOut(x)

Annotates the argument `x` as read-write for dependency tracking; combines
the syncing behavior of both `In` and `Out`.
"""
struct InOut{T}
    x::T
end

# Convenience union of all dependency annotation wrappers.
const AnyInOut = Union{In,Out,InOut}
15
+
16
"""
    DataDepsTaskQueue <: AbstractTaskQueue

A task queue that tracks read/write data dependencies between tasks, based
on `In`/`Out`/`InOut` argument annotations. In eager mode (`static=false`),
tasks are forwarded to `upper_queue` as they arrive, with `syncdeps`
populated from previously-seen accesses of the same arguments. In static
mode (`static=true`), tasks are instead accumulated (together with a
data-dependency graph) for later scheduling by `distribute_tasks!`.
"""
struct DataDepsTaskQueue <: AbstractTaskQueue
    # The queue above us
    upper_queue::AbstractTaskQueue
    # The mapping of unique objects to previously-launched tasks,
    # and their data dependency on the object (read, write)
    deps::IdDict{Any, Vector{Pair{Tuple{Bool,Bool}, EagerThunk}}}
    # Whether to analyze the DAG statically or eagerly
    static::Bool
    # If static=true, the set of tasks that have already been seen
    seen_tasks::Union{Vector{Pair{EagerTaskSpec,EagerThunk}},Nothing}
    # If static=true, the data-dependency graph of all tasks
    g::Union{SimpleDiGraph{Int},Nothing}
    # If static=true, the mapping from task to graph ID
    task_to_id::Union{Dict{EagerThunk,Int},Nothing}
    function DataDepsTaskQueue(upper_queue; static::Bool=false)
        deps = IdDict{Any, Vector{Pair{Tuple{Bool,Bool}, EagerThunk}}}()
        # Static mode needs the bookkeeping structures; eager mode leaves
        # them as `nothing` since tasks are forwarded immediately.
        if static
            seen_tasks = Pair{EagerTaskSpec,EagerThunk}[]
            g = SimpleDiGraph()
            task_to_id = Dict{EagerThunk,Int}()
        else
            seen_tasks = nothing
            g = nothing
            task_to_id = nothing
        end
        return new(upper_queue, deps, static, seen_tasks, g, task_to_id)
    end
end
44
+
45
# Enqueue a single task, unwrapping In/Out/InOut argument annotations and
# populating the task's `syncdeps` option with the previously-launched tasks
# it must wait on (readers wait on prior writers; writers wait on prior
# readers and writers). In static mode the task is only recorded (plus graph
# edges); in eager mode it is immediately forwarded to the upper queue.
function enqueue!(queue::DataDepsTaskQueue, fullspec::Pair{EagerTaskSpec,EagerThunk})
    # If static, record this task and its edges in the graph
    if queue.static
        g = queue.g
        task_to_id = queue.task_to_id
    end

    spec, task = fullspec
    if queue.static
        # Assign this task the next vertex ID in the dependency graph.
        add_vertex!(g)
        task_to_id[task] = our_task_id = nv(g)
    end
    opts = spec.options
    # Reuse the caller-provided syncdeps set if present, else start fresh.
    syncdeps = get(Set{Any}, opts, :syncdeps)
    # Accumulate this task's own (arg => (read, write)) accesses; they are
    # registered in `queue.deps` only after all args are scanned, so that an
    # argument passed twice to this task doesn't make it depend on itself.
    deps_to_add = Vector{Pair{Any, Tuple{Bool,Bool}}}()
    for (idx, (pos, arg)) in enumerate(spec.args)
        readdep = false
        writedep = false
        # Unwrap the annotation; unannotated arguments default to read-only.
        if arg isa In
            readdep = true
            arg = arg.x
        elseif arg isa Out
            writedep = true
            arg = arg.x
        elseif arg isa InOut
            readdep = true
            writedep = true
            arg = arg.x
        else
            readdep = true
        end
        # Replace the annotated argument with the bare value in the spec.
        spec.args[idx] = pos => arg

        push!(deps_to_add, arg => (readdep, writedep))

        # First access of this object ever seen: nothing to sync against.
        if !haskey(queue.deps, arg)
            continue
        end
        argdeps = queue.deps[arg]::Vector{Pair{Tuple{Bool,Bool}, EagerThunk}}
        if readdep
            # When you have an in dependency, sync with the previous out
            for ((other_readdep::Bool, other_writedep::Bool),
                 other_task::EagerThunk) in argdeps
                if other_writedep
                    push!(syncdeps, other_task)
                    if queue.static
                        other_task_id = task_to_id[other_task]
                        add_edge!(g, other_task_id, our_task_id)
                    end
                end
            end
        end
        if writedep
            # When you have an out dependency, sync with the previous in or out
            for ((other_readdep::Bool, other_writedep::Bool),
                 other_task::EagerThunk) in argdeps
                if other_readdep || other_writedep
                    push!(syncdeps, other_task)
                    if queue.static
                        other_task_id = task_to_id[other_task]
                        add_edge!(g, other_task_id, our_task_id)
                    end
                end
            end
        end
    end
    # Now register this task as an accessor of each of its arguments, for
    # later tasks to sync against.
    for (arg, (readdep, writedep)) in deps_to_add
        argdeps = get!(queue.deps, arg) do
            Vector{Pair{Tuple{Bool,Bool}, EagerThunk}}()
        end
        push!(argdeps, (readdep, writedep) => task)
    end

    spec.options = merge(opts, (;syncdeps,))

    if queue.static
        # Defer launch; `distribute_tasks!` will schedule it later.
        push!(queue.seen_tasks, fullspec)
    else
        enqueue!(queue.upper_queue, fullspec)
    end
end
126
# Batch enqueue is intentionally disabled: per-task enqueueing would register
# each task's accesses in `queue.deps` before later tasks in the same batch
# are processed, which is not the desired batch semantics (see FIXME).
function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{EagerTaskSpec,EagerThunk}})
    # FIXME: Don't register as previous tasks until very end
    error("Not yet implemented")
    # NOTE(review): the loop below is unreachable until the FIXME above is
    # resolved; kept as a sketch of the intended implementation.
    for spec in specs
        enqueue!(queue, spec)
    end
end
133
+
134
# Schedule all tasks recorded by a static DataDepsTaskQueue: assign each task
# round-robin to a processor, inserting copy-to tasks (to move each Array
# argument to the chosen worker before execution) and copy-from tasks (to
# copy results back afterwards), chained via syncdeps and the per-argument
# "owner" (the latest task allowed to touch that data).
function distribute_tasks!(queue::DataDepsTaskQueue)
    # "Distributes" the graph by making cuts
    #= TODO: We currently assume:
    # - All data is local to this worker
    # - All data is the same size
    # - All tasks take the same amount of time to execute
    # - Tasks executing on other workers will have data moved for them
    # - All data will be updated locally at the end of the computation
    =#
    # FIXME: Don't do round-robin
    # FIXME: Skip this if only one proc
    all_procs = Processor[]
    for w in procs()
        append!(all_procs, get_processors(OSProc(w)))
    end
    # NOTE(review): currently unused — the copy logic below hard-codes
    # data_worker = 1; see the "Track initial data locality" TODOs.
    data_locality = IdDict{Any,Int}(data=>myid() for data in keys(queue.deps))

    # Make a copy of each piece of data on each worker
    remote_args = Dict{Int,IdDict{Any,Any}}(w=>IdDict{Any,Any}() for w in procs())
    # FIXME: Owner can repeat (same arg twice to one task)
    args_owner = IdDict{Any,Any}(arg=>nothing for arg in keys(queue.deps))
    for w in procs()
        for data in keys(queue.deps)
            # Only Array data is mirrored/copied; other objects are untouched.
            data isa Array || continue
            if w == myid()
                remote_args[w][data] = data
            else
                # TODO: Can't use @mutable with custom Chunk scope
                # remote_args[w][data] = Dagger.@mutable worker=w copy(data)
                remote_args[w][data] = remotecall_fetch(Dagger.tochunk, w, data)
            end
        end
    end

    # Round-robin assign tasks to processors
    proc_idx = 1
    for (spec, task) in queue.seen_tasks
        our_proc = all_procs[proc_idx]
        our_proc_worker = root_worker_id(our_proc)

        # Spawn copies before and after user's task, as necessary
        @dagdebug nothing :spawn_datadeps "Scheduling $(spec.f)"
        task_queue = get_options(:task_queue)
        task_syncdeps = Set()
        # Snapshot the args; spec.args entries may be rewritten to remote
        # chunks below, but ownership updates key off the original values.
        task_args = copy(spec.args)

        # Copy args from local to remote
        for (idx, (pos, arg)) in enumerate(task_args)
            arg isa Array || continue
            data_worker = 1
            # TODO: Track initial data locality:
            # data_worker = data_locality[arg]
            if our_proc_worker != data_worker
                # Add copy-to operation (depends on latest owner of arg)
                @dagdebug nothing :spawn_datadeps "Enqueueing copy-to: $data_worker => $our_proc_worker"
                arg_local = remote_args[data_worker][arg]
                @assert arg_local === spec.args[idx][2]
                arg_remote = remote_args[our_proc_worker][arg]
                copy_to_scope = scope(worker=our_proc_worker)
                copy_to_syncdeps = Set()
                if (owner = args_owner[arg]) !== nothing
                    @dagdebug nothing :spawn_datadeps "(copy-to arg) Depending on previous owner"
                    push!(copy_to_syncdeps, owner)
                end
                copy_to = Dagger.@spawn scope=copy_to_scope syncdeps=copy_to_syncdeps copyto!(arg_remote, arg_local)
                push!(task_syncdeps, copy_to)
                # The task will consume the remote copy, not the local value.
                spec.args[idx] = pos => arg_remote
                # TODO: Allow changing data locality:
                # data_locality[arg] = our_proc_worker
            else
                # Data already local to the target worker; just honor the
                # previous owner's ordering.
                if (owner = args_owner[arg]) !== nothing
                    @dagdebug nothing :spawn_datadeps "(local arg) Depending on previous owner"
                    push!(task_syncdeps, owner)
                end
            end
        end

        # Launch user's task
        syncdeps = get(Set, spec.options, :syncdeps)
        for other_task in task_syncdeps
            push!(syncdeps, other_task)
        end
        task_scope = scope(worker=our_proc_worker)
        spec.options = merge(spec.options, (;syncdeps, scope=task_scope))
        enqueue!(task_queue, spec=>task)
        # This task becomes the latest owner of every Array it touched.
        for (_, arg) in task_args
            arg isa Array || continue
            args_owner[arg] = task
        end

        # Copy args from remote to local
        # TODO: Don't always copy to-and-from
        for (_, arg) in task_args
            arg isa Array || continue
            data_worker = 1
            # TODO: Track initial data locality:
            # data_worker = data_locality[arg]
            if our_proc_worker != data_worker
                # Add copy-from operation
                @dagdebug nothing :spawn_datadeps "Enqueueing copy-from: $our_proc_worker => $data_worker"
                arg_local = remote_args[data_worker][arg]
                arg_remote = remote_args[our_proc_worker][arg]
                copy_from_scope = scope(worker=data_worker)
                copy_from_syncdeps = Set([task])
                copy_from = Dagger.@spawn scope=copy_from_scope syncdeps=copy_from_syncdeps copyto!(arg_local, arg_remote)

                # Set copy-from as latest owner of arg
                args_owner[arg] = copy_from

                # TODO: Allow changing data locality:
                # data_locality[arg] = our_proc_worker
            end
        end
        # Advance round-robin cursor, wrapping back to the first processor.
        proc_idx = mod1(proc_idx+1, length(all_procs))
    end
end
250
+
251
"""
    spawn_datadeps(f::Base.Callable; static::Bool=false)

Run `f` with the current task queue replaced by a `DataDepsTaskQueue`, so
that tasks spawned within `f` have their `In`/`Out`/`InOut` argument
annotations translated into `syncdeps` ordering. With `static=true`, tasks
are collected during `f` and scheduled afterwards via `distribute_tasks!`.
Returns the result of `f`.
"""
function spawn_datadeps(f::Base.Callable; static::Bool=false)
    upper = get_options(:task_queue, EagerTaskQueue())
    deps_queue = DataDepsTaskQueue(upper; static)
    result = with_options(f; task_queue=deps_queue)
    # Static mode defers all launches until `f` has finished.
    deps_queue.static && distribute_tasks!(deps_queue)
    return result
end
259
+
260
# FIXME: Move this elsewhere
"""
    WaitAllQueue <: AbstractTaskQueue

A pass-through task queue that records every task enqueued through it, so
that `wait_all` can later `fetch` each one before returning.
"""
struct WaitAllQueue <: AbstractTaskQueue
    # The queue that specs are forwarded to unchanged
    upper_queue::AbstractTaskQueue
    # All tasks enqueued so far, in enqueue order
    tasks::Vector{EagerThunk}
end
265
# Record the spawned task for `wait_all`, then forward the spec unchanged.
function enqueue!(queue::WaitAllQueue, spec::Pair{EagerTaskSpec,EagerThunk})
    push!(queue.tasks, last(spec))
    return enqueue!(queue.upper_queue, spec)
end
269
# Record every task in the batch for `wait_all`, then forward the whole
# batch to the wrapped queue unchanged.
function enqueue!(queue::WaitAllQueue, specs::Vector{Pair{EagerTaskSpec,EagerThunk}})
    append!(queue.tasks, last.(specs))
    return enqueue!(queue.upper_queue, specs)
end
275
"""
    wait_all(f)

Run `f` under a `WaitAllQueue`, then block until every task spawned during
`f` has completed (via `fetch`). Returns the result of `f`.
"""
function wait_all(f)
    collector = WaitAllQueue(get_options(:task_queue, EagerTaskQueue()), EagerThunk[])
    result = with_options(f; task_queue=collector)
    # Block on each recorded task before handing back `f`'s result.
    foreach(fetch, collector.tasks)
    return result
end
0 commit comments