Skip to content

Commit e3251bc

Browse files
authored
Merge pull request #61 from bicycle1885/refactor-compute
Refactor compute.jl
2 parents dce6f2c + 397bb58 commit e3251bc

File tree

1 file changed

+16
-34
lines changed

1 file changed

+16
-34
lines changed

Diff for: src/compute.jl

+16-34
Original file line numberDiff line numberDiff line change
@@ -148,26 +148,25 @@ function thunkize(ctx, c::Cat)
148148
end
149149
thunkize(ctx, x::AbstractChunk) = x
150150
thunkize(ctx, x::Thunk) = x
151+
151152
function finish_task!(state, node, node_order; free=true)
152-
deps = sort([i for i in state[:dependents][node]], by=node_order)
153-
immediate_next = false
154153
if istask(node) && node.cache
155154
node.cache_ref = Nullable{Any}(state[:cache][node])
156155
end
157-
for dep in deps
156+
immediate_next = false
157+
for dep in sort!(collect(state[:dependents][node]), by=node_order)
158158
set = state[:waiting][dep]
159159
pop!(set, node)
160160
if isempty(set)
161161
pop!(state[:waiting], dep)
162-
immediate_next = true
163162
push!(state[:ready], dep)
163+
immediate_next = true
164164
end
165165
# todo: free data
166166
end
167167
for inp in inputs(node)
168168
if inp in keys(state[:waiting_data])
169169
s = state[:waiting_data][inp]
170-
#@show s
171170
if node in s
172171
pop!(s, node)
173172
end
@@ -197,12 +196,7 @@ function compute(ctx, d::Thunk)
197196
ps = procs(ctx)
198197
chan = Channel{Any}(32)
199198
deps = dependents(d)
200-
ndeps = noffspring(deps)
201-
ord = order(d, ndeps)
202-
203-
sort_ord = collect(ord)
204-
sortord = x -> istask(x[1]) ? x[1].id : 0
205-
sort_ord = sort(sort_ord, by=sortord)
199+
ord = order(d, noffspring(deps))
206200

207201
node_order = x -> -get(ord, x, 0)
208202
state = start_state(deps, node_order)
@@ -218,21 +212,16 @@ function compute(ctx, d::Thunk)
218212

219213
while !isempty(state[:waiting]) || !isempty(state[:ready]) || !isempty(state[:running])
220214
proc, thunk_id, res = take!(chan)
221-
222215
if isa(res, CapturedException) || isa(res, RemoteException)
223216
rethrow(res)
224217
end
225218
node = _thunk_dict[thunk_id]
226219
@logmsg("W$(proc.pid) - $node ($(node.f)) input:$(node.inputs)")
227220
state[:cache][node] = res
228-
#@show state[:cache]
229-
#@show ord
230-
# if any of this guy's dependents are waiting,
231-
# update them
232-
@dbg timespan_start(ctx, :scheduler, thunk_id, master)
233221

222+
# if any of this guy's dependents are waiting, update them
223+
@dbg timespan_start(ctx, :scheduler, thunk_id, master)
234224
immediate_next = finish_task!(state, node, node_order)
235-
236225
if !isempty(state[:ready])
237226
if immediate_next
238227
# fast path
@@ -248,7 +237,6 @@ function compute(ctx, d::Thunk)
248237
end
249238
@dbg timespan_end(ctx, :scheduler, thunk_id, master)
250239
end
251-
252240
state[:cache][d]
253241
end
254242

@@ -342,7 +330,6 @@ function fire_task!(ctx, thunk, proc, state, chan, node_order)
342330
data = map(thunk.inputs) do x
343331
istask(x) ? state[:cache][x] : x
344332
end
345-
346333
async_apply(ctx, proc, thunk.id, thunk.f, data, chan, thunk.get_result, thunk.persist)
347334
end
348335

@@ -364,8 +351,6 @@ function dependents(node::Thunk, deps=Dict())
364351
deps
365352
end
366353

367-
368-
369354
"""
370355
recursively find the number of taks dependent on each task in the DAG.
371356
Input: dependents dict
@@ -383,7 +368,6 @@ function noffspring(n, dpents)
383368
end
384369
end
385370

386-
387371
"""
388372
Given a root node of the DAG, calculates a total order for tie-braking
389373
@@ -397,18 +381,16 @@ Args:
397381
- ndeps: result of `noffspring`
398382
"""
399383
function order(node::Thunk, ndeps)
400-
order([node], ndeps, 0)[2]
401-
end
402-
403-
function order(nodes::AbstractArray, ndeps, c, output=Dict())
404-
405-
for node in nodes
406-
c+=1
407-
output[node] = c
408-
nxt = sort(Any[n for n in inputs(node)], by=k->get(ndeps,k,0))
409-
c, output = order(nxt, ndeps, c, output)
384+
function recur(nodes, s)
385+
for n in nodes
386+
output[n] = s += 1
387+
s = recur(sort!(collect(Any, inputs(n)), by=k->get(ndeps,k,0)), s)
388+
end
389+
return s
410390
end
411-
c, output
391+
output = Dict{Any,Int}()
392+
recur([node], 0)
393+
return output
412394
end
413395

414396
function start_state(deps::Dict, node_order)

0 commit comments

Comments
 (0)