Skip to content

Commit 7364106

Browse files
Authored — Merge pull request #472 from JuliaParallel/jps/darray-no-cache
DArray: Remove the stage cache
2 parents 8172122 + 0aad01f commit 7364106

12 files changed: +42 −74 lines changed

Diff for: src/array/darray.jl

+4-34
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ Base.IndexStyle(::Type{<:ArrayOp}) = IndexCartesian()
8080

8181
collect(x::ArrayOp) = collect(fetch(x))
8282

83-
_to_darray(x::ArrayOp) = cached_stage(Context(global_context()), x)::DArray
83+
_to_darray(x::ArrayOp) = stage(Context(global_context()), x)::DArray
8484
Base.fetch(x::ArrayOp) = fetch(_to_darray(x))
8585

8686
collect(x::Computation) = collect(fetch(x))
8787

88-
Base.fetch(x::Computation) = fetch(cached_stage(Context(global_context()), x))
88+
Base.fetch(x::Computation) = fetch(stage(Context(global_context()), x))
8989

9090
function Base.show(io::IO, ::MIME"text/plain", x::ArrayOp)
9191
write(io, string(typeof(x)))
@@ -288,36 +288,6 @@ function Base.fetch(c::DArray{T}) where T
288288
end
289289
end
290290

291-
global _stage_cache = WeakKeyDict{Context, Dict}()
292-
293-
"""
294-
cached_stage(ctx::Context, x)
295-
296-
A memoized version of stage. It is important that the
297-
tasks generated for the same `DArray` have the same
298-
identity, for example:
299-
300-
```julia
301-
A = rand(Blocks(100,100), Float64, 1000, 1000)
302-
compute(A+A')
303-
```
304-
305-
must not result in computation of `A` twice.
306-
"""
307-
function cached_stage(ctx::Context, x)
308-
cache = if !haskey(_stage_cache, ctx)
309-
_stage_cache[ctx] = Dict()
310-
else
311-
_stage_cache[ctx]
312-
end
313-
314-
if haskey(cache, x)
315-
cache[x]
316-
else
317-
cache[x] = stage(ctx, x)
318-
end
319-
end
320-
321291
Base.@deprecate_binding Cat DArray
322292
Base.@deprecate_binding ComputedArray DArray
323293

@@ -352,15 +322,15 @@ end
352322
function stage(ctx::Context, d::Distribute)
353323
if isa(d.data, ArrayOp)
354324
# distributing a distributed array
355-
x = cached_stage(ctx, d.data)
325+
x = stage(ctx, d.data)
356326
if d.domainchunks == domainchunks(x)
357327
return x # already properly distributed
358328
end
359329
Nd = ndims(x)
360330
T = eltype(d.data)
361331
concat = x.concat
362332
cs = map(d.domainchunks) do idx
363-
chunks = cached_stage(ctx, x[idx]).chunks
333+
chunks = stage(ctx, x[idx]).chunks
364334
shape = size(chunks)
365335
# TODO: fix hashing
366336
#hash = uhash(idx, Base.hash(Distribute, Base.hash(d.data)))

Diff for: src/array/getindex.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ GetIndex(input::ArrayOp, idx::Tuple) =
77
GetIndex{eltype(input), ndims(input)}(input, idx)
88

99
function stage(ctx::Context, gidx::GetIndex)
10-
inp = cached_stage(ctx, gidx.input)
10+
inp = stage(ctx, gidx.input)
1111

1212
dmn = domain(inp)
1313
idxs = [if isa(gidx.idx[i], Colon)
@@ -32,7 +32,7 @@ struct GetIndexScalar <: Computation
3232
end
3333

3434
function stage(ctx::Context, gidx::GetIndexScalar)
35-
inp = cached_stage(ctx, gidx.input)
35+
inp = stage(ctx, gidx.input)
3636
s = view(inp, ArrayDomain(gidx.idx))
3737
Dagger.@spawn identity(collect(s)[1])
3838
end

Diff for: src/array/map-reduce.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ size(x::Map) = size(x.inputs[1])
1010
Map(f, inputs::Tuple) = Map{Any, ndims(inputs[1])}(f, inputs)
1111

1212
function stage(ctx::Context, node::Map)
13-
inputs = Any[cached_stage(ctx, n) for n in node.inputs]
13+
inputs = Any[stage(ctx, n) for n in node.inputs]
1414
primary = inputs[1] # all others will align to this guy
1515
domains = domainchunks(primary)
1616
thunks = similar(domains, Any)
@@ -130,7 +130,7 @@ function Base.reduce(f::Function, x::ArrayOp; dims = nothing)
130130
end
131131

132132
function stage(ctx::Context, r::Reducedim)
133-
inp = cached_stage(ctx, r.input)
133+
inp = stage(ctx, r.input)
134134
thunks = let op = r.op, dims=r.dims
135135
# do reducedim on each block
136136
tmp = map(p->Dagger.spawn(b->reduce(op,b,dims=dims), p), chunks(inp))

Diff for: src/array/matrix.jl

+12-12
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function _ctranspose(x::AbstractArray)
4545
end
4646

4747
function stage(ctx::Context, node::Transpose)
48-
inp = cached_stage(ctx, node.input)
48+
inp = stage(ctx, node.input)
4949
thunks = _ctranspose(chunks(inp))
5050
return DArray(eltype(inp), domain(inp)', domainchunks(inp)', thunks, inp.partitioning', inp.concat)
5151
end
@@ -143,7 +143,7 @@ function promote_distribution(ctx::Context, m::MatMul, a,b)
143143
pb = domainchunks(b)
144144

145145
d = DomainBlocks((1,1), (pa.cumlength[2], pb.cumlength[2])) # FIXME: this is not generic
146-
a, cached_stage(ctx, Distribute(d, b))
146+
a, stage(ctx, Distribute(d, b))
147147
end
148148

149149
function stage_operands(ctx::Context, m::MatMul, a, b)
@@ -152,14 +152,14 @@ function stage_operands(ctx::Context, m::MatMul, a, b)
152152
end
153153
# take the row distribution of a and get b onto that.
154154

155-
stg_a = cached_stage(ctx, a)
156-
stg_b = cached_stage(ctx, b)
155+
stg_a = stage(ctx, a)
156+
stg_b = stage(ctx, b)
157157
promote_distribution(ctx, m, stg_a, stg_b)
158158
end
159159

160160
"An operand which should be distributed as per convenience"
161161
function stage_operands(ctx::Context, ::MatMul, a::ArrayOp, b::PromotePartition{T,1}) where T
162-
stg_a = cached_stage(ctx, a)
162+
stg_a = stage(ctx, a)
163163
dmn_a = domain(stg_a)
164164
dchunks_a = domainchunks(stg_a)
165165
dmn_b = domain(b.data)
@@ -168,19 +168,19 @@ function stage_operands(ctx::Context, ::MatMul, a::ArrayOp, b::PromotePartition{
168168
end
169169
dmn_out = DomainBlocks((1,),(dchunks_a.cumlength[2],))
170170

171-
stg_a, cached_stage(ctx, Distribute(dmn_out, b.data))
171+
stg_a, stage(ctx, Distribute(dmn_out, b.data))
172172
end
173173

174174
function stage_operands(ctx::Context, ::MatMul, a::PromotePartition, b::ArrayOp)
175175

176176
if size(a, 2) != size(b, 1)
177177
throw(DimensionMismatch("Cannot promote array of domain $(dmn_b) to multiply with an array of size $(dmn_a)"))
178178
end
179-
stg_b = cached_stage(ctx, b)
179+
stg_b = stage(ctx, b)
180180

181181
ps = domainchunks(stg_b)
182182
dmn_out = DomainBlocks((1,1),([size(a.data, 1)], ps.cumlength[1],))
183-
cached_stage(ctx, Distribute(dmn_out, a.data)), stg_b
183+
stage(ctx, Distribute(dmn_out, a.data)), stg_b
184184
end
185185

186186
function stage(ctx::Context, mul::MatMul{T,N}) where {T,N}
@@ -215,11 +215,11 @@ scale(l::ArrayOp, r::ArrayOp) = _to_darray(Scale(l, r))
215215
function stage_operand(ctx::Context, ::Scale, a, b::PromotePartition)
216216
ps = domainchunks(a)
217217
b_parts = DomainBlocks((1,), (ps.cumlength[1],))
218-
cached_stage(ctx, Distribute(b_parts, b.data))
218+
stage(ctx, Distribute(b_parts, b.data))
219219
end
220220

221221
function stage_operand(ctx::Context, ::Scale, a, b)
222-
cached_stage(ctx, b)
222+
stage(ctx, b)
223223
end
224224

225225
function _scale(l, r)
@@ -231,7 +231,7 @@ function _scale(l, r)
231231
end
232232

233233
function stage(ctx::Context, scal::Scale)
234-
r = cached_stage(ctx, scal.r)
234+
r = stage(ctx, scal.r)
235235
l = stage_operand(ctx, scal, r, scal.l)
236236

237237
@assert size(domain(r), 1) == size(domain(l), 1)
@@ -265,7 +265,7 @@ function Base.cat(d::ArrayDomain, ds::ArrayDomain...; dims::Int)
265265
end
266266

267267
function stage(ctx::Context, c::Concat)
268-
inp = Any[cached_stage(ctx, x) for x in c.inputs]
268+
inp = Any[stage(ctx, x) for x in c.inputs]
269269

270270
dmns = map(domain, inp)
271271
dims = [[i == c.axis ? 0 : i for i in size(d)] for d in dmns]

Diff for: src/array/operators.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,19 @@ BCast(b::Broadcasted) = BCast{typeof(b), combine_eltypes(b.f, b.args), length(ax
2828
size(x::BCast) = map(length, axes(x.bcasted))
2929

3030
function stage_operands(ctx::Context, ::BCast, xs::ArrayOp...)
31-
map(x->cached_stage(ctx, x), xs)
31+
map(x->stage(ctx, x), xs)
3232
end
3333

3434
function stage_operands(ctx::Context, ::BCast, x::ArrayOp, y::PromotePartition)
35-
stg_x = cached_stage(ctx, x)
35+
stg_x = stage(ctx, x)
3636
y1 = Distribute(domain(stg_x), y.data)
37-
stg_x, cached_stage(ctx, y1)
37+
stg_x, stage(ctx, y1)
3838
end
3939

4040
function stage_operands(ctx::Context, ::BCast, x::PromotePartition, y::ArrayOp)
41-
stg_y = cached_stage(ctx, y)
41+
stg_y = stage(ctx, y)
4242
x1 = Distribute(domain(stg_y), x.data)
43-
cached_stage(ctx, x1), stg_y
43+
stage(ctx, x1), stg_y
4444
end
4545

4646
struct DaggerBroadcastStyle <: BroadcastStyle end
@@ -57,7 +57,7 @@ function stage(ctx::Context, node::BCast{B,T,N}) where {B,T,N}
5757
bc = Broadcast.flatten(node.bcasted)
5858
args = bc.args
5959
args1 = map(args) do x
60-
x isa ArrayOp ? cached_stage(ctx, x) : x
60+
x isa ArrayOp ? stage(ctx, x) : x
6161
end
6262
ds = map(x->x isa DArray ? domainchunks(x) : nothing, args1)
6363
sz = size(node)
@@ -84,7 +84,7 @@ function stage(ctx::Context, node::BCast{B,T,N}) where {B,T,N}
8484
end
8585
end |> Tuple
8686
dmn = DomainBlocks(ntuple(_->1, length(s)), splits)
87-
cached_stage(ctx, Distribute(dmn, part, arg)).chunks
87+
stage(ctx, Distribute(dmn, part, arg)).chunks
8888
else
8989
arg
9090
end
@@ -105,7 +105,7 @@ end
105105
mapchunk(f::Function, xs::ArrayOp...) = MapChunk(f, xs)
106106
Base.@deprecate mappart(args...) mapchunk(args...)
107107
function stage(ctx::Context, node::MapChunk)
108-
inputs = map(x->cached_stage(ctx, x), node.input)
108+
inputs = map(x->stage(ctx, x), node.input)
109109
thunks = map(map(chunks, inputs)...) do ps...
110110
Dagger.spawn(node.f, map(p->nothing=>p, ps)...)
111111
end

Diff for: src/array/setindex.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function setindex(x::ArrayOp, val, idxs...)
1212
end
1313

1414
function stage(ctx::Context, sidx::SetIndex)
15-
inp = cached_stage(ctx, sidx.input)
15+
inp = stage(ctx, sidx.input)
1616

1717
dmn = domain(inp)
1818
idxs = [if isa(sidx.idx[i], Colon)

Diff for: src/compute.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export stage, cached_stage, compute, debug_compute, cleanup
1+
export compute, debug_compute
22

33
###### Scheduler #######
44

Diff for: src/file-io.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ function save(p::Computation, name::AbstractString)
335335
end
336336

337337
function stage(ctx::Context, s::Save)
338-
x = cached_stage(ctx, s.input)
338+
x = stage(ctx, s.input)
339339
dir_path = s.name * "_data"
340340
if !isdir(dir_path)
341341
mkdir(dir_path)

Diff for: src/sch/Sch.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,7 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
10631063
@async begin
10641064
timespan_start(ctx, :fire, gproc.pid, 0)
10651065
try
1066-
remotecall_wait(do_tasks, gproc.pid, proc, state.chan, [ts])
1066+
remotecall_wait(do_tasks, gproc.pid, proc, state.chan, [ts]);
10671067
catch err
10681068
bt = catch_backtrace()
10691069
thunk_id = ts[1]
@@ -1552,7 +1552,7 @@ function do_task(to_proc, task_desc)
15521552
=#
15531553
x = @invokelatest move(to_proc, x)
15541554
#end
1555-
@dagdebug thunk_id :move "Moved argument $id to $to_proc: $x"
1555+
@dagdebug thunk_id :move "Moved argument $id to $to_proc: $(typeof(x))"
15561556
timespan_finish(ctx, :move, (;thunk_id, id), (;f, id, data=x); tasks=[Base.current_task()])
15571557
return x
15581558
end
@@ -1595,7 +1595,7 @@ function do_task(to_proc, task_desc)
15951595
# FIXME
15961596
#gcnum_start = Base.gc_num()
15971597

1598-
@dagdebug thunk_id :execute "Executing"
1598+
@dagdebug thunk_id :execute "Executing $(typeof(f))"
15991599

16001600
result_meta = try
16011601
# Set TLS variables

Diff for: src/sch/util.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ function print_sch_status(io::IO, state, thunk; offset=0, limit=5, max_inputs=3)
228228
println(io, "$(thunk.id): $(thunk.f)")
229229
for (idx, input) in enumerate(thunk.syncdeps)
230230
if input isa WeakThunk
231-
input = unwrap_weak(input)
231+
input = Dagger.unwrap_weak(input)
232232
if input === nothing
233233
println(io, repeat(' ', offset+2), "(???)")
234234
continue

Diff for: src/threadproc.jl

+6-8
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,25 @@ iscompatible_func(proc::ThreadProc, opts, f) = true
1212
iscompatible_arg(proc::ThreadProc, opts, x) = true
1313
function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @nospecialize(kwargs...))
1414
tls = get_tls()
15+
# FIXME: Use return type of the call to specialize container
16+
result = Ref{Any}()
1517
task = Task() do
1618
set_tls!(tls)
1719
TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id)
18-
@invokelatest f(args...; kwargs...)
20+
result[] = @invokelatest f(args...; kwargs...)
21+
return
1922
end
2023
set_task_tid!(task, proc.tid)
2124
schedule(task)
2225
try
2326
fetch(task)
27+
return result[]
2428
catch err
25-
@static if VERSION < v"1.7-rc1"
26-
stk = Base.catch_stack(task)
27-
else
28-
stk = Base.current_exceptions(task)
29-
end
30-
err, frames = stk[1]
29+
err, frames = Base.current_exceptions(task)[1]
3130
rethrow(CapturedException(err, frames))
3231
end
3332
end
3433
get_parent(proc::ThreadProc) = OSProc(proc.owner)
3534
default_enabled(proc::ThreadProc) = true
3635

3736
# TODO: ThreadGroupProc?
38-

Diff for: src/thunk.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ Base.hash(x::Thunk, h::UInt) = hash(x.id, hash(h, 0x7ad3bac49089a05f % UInt))
438438
Base.isequal(x::Thunk, y::Thunk) = x.id==y.id
439439

440440
function show_thunk(io::IO, t)
441-
lvl = get(io, :lazy_level, 2)
441+
lvl = get(io, :lazy_level, 0)
442442
f = if t.f isa Chunk
443443
Tf = t.f.chunktype
444444
if isdefined(Tf, :instance)

0 commit comments

Comments (0)