Skip to content

Commit 0aad01f

Browse files
committed
DArray: Remove stage caching
1 parent 6bc14e5 commit 0aad01f

File tree

8 files changed

+31
-61
lines changed

8 files changed

+31
-61
lines changed

Diff for: src/array/darray.jl

+4-34
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ Base.IndexStyle(::Type{<:ArrayOp}) = IndexCartesian()
8080

8181
collect(x::ArrayOp) = collect(fetch(x))
8282

83-
_to_darray(x::ArrayOp) = cached_stage(Context(global_context()), x)::DArray
83+
_to_darray(x::ArrayOp) = stage(Context(global_context()), x)::DArray
8484
Base.fetch(x::ArrayOp) = fetch(_to_darray(x))
8585

8686
collect(x::Computation) = collect(fetch(x))
8787

88-
Base.fetch(x::Computation) = fetch(cached_stage(Context(global_context()), x))
88+
Base.fetch(x::Computation) = fetch(stage(Context(global_context()), x))
8989

9090
function Base.show(io::IO, ::MIME"text/plain", x::ArrayOp)
9191
write(io, string(typeof(x)))
@@ -288,36 +288,6 @@ function Base.fetch(c::DArray{T}) where T
288288
end
289289
end
290290

291-
global _stage_cache = WeakKeyDict{Context, Dict}()
292-
293-
"""
294-
cached_stage(ctx::Context, x)
295-
296-
A memoized version of stage. It is important that the
297-
tasks generated for the same `DArray` have the same
298-
identity, for example:
299-
300-
```julia
301-
A = rand(Blocks(100,100), Float64, 1000, 1000)
302-
compute(A+A')
303-
```
304-
305-
must not result in computation of `A` twice.
306-
"""
307-
function cached_stage(ctx::Context, x)
308-
cache = if !haskey(_stage_cache, ctx)
309-
_stage_cache[ctx] = Dict()
310-
else
311-
_stage_cache[ctx]
312-
end
313-
314-
if haskey(cache, x)
315-
cache[x]
316-
else
317-
cache[x] = stage(ctx, x)
318-
end
319-
end
320-
321291
Base.@deprecate_binding Cat DArray
322292
Base.@deprecate_binding ComputedArray DArray
323293

@@ -352,15 +322,15 @@ end
352322
function stage(ctx::Context, d::Distribute)
353323
if isa(d.data, ArrayOp)
354324
# distributing a distributed array
355-
x = cached_stage(ctx, d.data)
325+
x = stage(ctx, d.data)
356326
if d.domainchunks == domainchunks(x)
357327
return x # already properly distributed
358328
end
359329
Nd = ndims(x)
360330
T = eltype(d.data)
361331
concat = x.concat
362332
cs = map(d.domainchunks) do idx
363-
chunks = cached_stage(ctx, x[idx]).chunks
333+
chunks = stage(ctx, x[idx]).chunks
364334
shape = size(chunks)
365335
# TODO: fix hashing
366336
#hash = uhash(idx, Base.hash(Distribute, Base.hash(d.data)))

Diff for: src/array/getindex.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ GetIndex(input::ArrayOp, idx::Tuple) =
77
GetIndex{eltype(input), ndims(input)}(input, idx)
88

99
function stage(ctx::Context, gidx::GetIndex)
10-
inp = cached_stage(ctx, gidx.input)
10+
inp = stage(ctx, gidx.input)
1111

1212
dmn = domain(inp)
1313
idxs = [if isa(gidx.idx[i], Colon)
@@ -32,7 +32,7 @@ struct GetIndexScalar <: Computation
3232
end
3333

3434
function stage(ctx::Context, gidx::GetIndexScalar)
35-
inp = cached_stage(ctx, gidx.input)
35+
inp = stage(ctx, gidx.input)
3636
s = view(inp, ArrayDomain(gidx.idx))
3737
Dagger.@spawn identity(collect(s)[1])
3838
end

Diff for: src/array/map-reduce.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ size(x::Map) = size(x.inputs[1])
1010
Map(f, inputs::Tuple) = Map{Any, ndims(inputs[1])}(f, inputs)
1111

1212
function stage(ctx::Context, node::Map)
13-
inputs = Any[cached_stage(ctx, n) for n in node.inputs]
13+
inputs = Any[stage(ctx, n) for n in node.inputs]
1414
primary = inputs[1] # all others will align to this guy
1515
domains = domainchunks(primary)
1616
thunks = similar(domains, Any)
@@ -130,7 +130,7 @@ function Base.reduce(f::Function, x::ArrayOp; dims = nothing)
130130
end
131131

132132
function stage(ctx::Context, r::Reducedim)
133-
inp = cached_stage(ctx, r.input)
133+
inp = stage(ctx, r.input)
134134
thunks = let op = r.op, dims=r.dims
135135
# do reducedim on each block
136136
tmp = map(p->Dagger.spawn(b->reduce(op,b,dims=dims), p), chunks(inp))

Diff for: src/array/matrix.jl

+12-12
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function _ctranspose(x::AbstractArray)
4545
end
4646

4747
function stage(ctx::Context, node::Transpose)
48-
inp = cached_stage(ctx, node.input)
48+
inp = stage(ctx, node.input)
4949
thunks = _ctranspose(chunks(inp))
5050
return DArray(eltype(inp), domain(inp)', domainchunks(inp)', thunks, inp.partitioning', inp.concat)
5151
end
@@ -143,7 +143,7 @@ function promote_distribution(ctx::Context, m::MatMul, a,b)
143143
pb = domainchunks(b)
144144

145145
d = DomainBlocks((1,1), (pa.cumlength[2], pb.cumlength[2])) # FIXME: this is not generic
146-
a, cached_stage(ctx, Distribute(d, b))
146+
a, stage(ctx, Distribute(d, b))
147147
end
148148

149149
function stage_operands(ctx::Context, m::MatMul, a, b)
@@ -152,14 +152,14 @@ function stage_operands(ctx::Context, m::MatMul, a, b)
152152
end
153153
# take the row distribution of a and get b onto that.
154154

155-
stg_a = cached_stage(ctx, a)
156-
stg_b = cached_stage(ctx, b)
155+
stg_a = stage(ctx, a)
156+
stg_b = stage(ctx, b)
157157
promote_distribution(ctx, m, stg_a, stg_b)
158158
end
159159

160160
"An operand which should be distributed as per convenience"
161161
function stage_operands(ctx::Context, ::MatMul, a::ArrayOp, b::PromotePartition{T,1}) where T
162-
stg_a = cached_stage(ctx, a)
162+
stg_a = stage(ctx, a)
163163
dmn_a = domain(stg_a)
164164
dchunks_a = domainchunks(stg_a)
165165
dmn_b = domain(b.data)
@@ -168,19 +168,19 @@ function stage_operands(ctx::Context, ::MatMul, a::ArrayOp, b::PromotePartition{
168168
end
169169
dmn_out = DomainBlocks((1,),(dchunks_a.cumlength[2],))
170170

171-
stg_a, cached_stage(ctx, Distribute(dmn_out, b.data))
171+
stg_a, stage(ctx, Distribute(dmn_out, b.data))
172172
end
173173

174174
function stage_operands(ctx::Context, ::MatMul, a::PromotePartition, b::ArrayOp)
175175

176176
if size(a, 2) != size(b, 1)
177177
throw(DimensionMismatch("Cannot promote array of domain $(dmn_b) to multiply with an array of size $(dmn_a)"))
178178
end
179-
stg_b = cached_stage(ctx, b)
179+
stg_b = stage(ctx, b)
180180

181181
ps = domainchunks(stg_b)
182182
dmn_out = DomainBlocks((1,1),([size(a.data, 1)], ps.cumlength[1],))
183-
cached_stage(ctx, Distribute(dmn_out, a.data)), stg_b
183+
stage(ctx, Distribute(dmn_out, a.data)), stg_b
184184
end
185185

186186
function stage(ctx::Context, mul::MatMul{T,N}) where {T,N}
@@ -215,11 +215,11 @@ scale(l::ArrayOp, r::ArrayOp) = _to_darray(Scale(l, r))
215215
function stage_operand(ctx::Context, ::Scale, a, b::PromotePartition)
216216
ps = domainchunks(a)
217217
b_parts = DomainBlocks((1,), (ps.cumlength[1],))
218-
cached_stage(ctx, Distribute(b_parts, b.data))
218+
stage(ctx, Distribute(b_parts, b.data))
219219
end
220220

221221
function stage_operand(ctx::Context, ::Scale, a, b)
222-
cached_stage(ctx, b)
222+
stage(ctx, b)
223223
end
224224

225225
function _scale(l, r)
@@ -231,7 +231,7 @@ function _scale(l, r)
231231
end
232232

233233
function stage(ctx::Context, scal::Scale)
234-
r = cached_stage(ctx, scal.r)
234+
r = stage(ctx, scal.r)
235235
l = stage_operand(ctx, scal, r, scal.l)
236236

237237
@assert size(domain(r), 1) == size(domain(l), 1)
@@ -265,7 +265,7 @@ function Base.cat(d::ArrayDomain, ds::ArrayDomain...; dims::Int)
265265
end
266266

267267
function stage(ctx::Context, c::Concat)
268-
inp = Any[cached_stage(ctx, x) for x in c.inputs]
268+
inp = Any[stage(ctx, x) for x in c.inputs]
269269

270270
dmns = map(domain, inp)
271271
dims = [[i == c.axis ? 0 : i for i in size(d)] for d in dmns]

Diff for: src/array/operators.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,19 @@ BCast(b::Broadcasted) = BCast{typeof(b), combine_eltypes(b.f, b.args), length(ax
2828
size(x::BCast) = map(length, axes(x.bcasted))
2929

3030
function stage_operands(ctx::Context, ::BCast, xs::ArrayOp...)
31-
map(x->cached_stage(ctx, x), xs)
31+
map(x->stage(ctx, x), xs)
3232
end
3333

3434
function stage_operands(ctx::Context, ::BCast, x::ArrayOp, y::PromotePartition)
35-
stg_x = cached_stage(ctx, x)
35+
stg_x = stage(ctx, x)
3636
y1 = Distribute(domain(stg_x), y.data)
37-
stg_x, cached_stage(ctx, y1)
37+
stg_x, stage(ctx, y1)
3838
end
3939

4040
function stage_operands(ctx::Context, ::BCast, x::PromotePartition, y::ArrayOp)
41-
stg_y = cached_stage(ctx, y)
41+
stg_y = stage(ctx, y)
4242
x1 = Distribute(domain(stg_y), x.data)
43-
cached_stage(ctx, x1), stg_y
43+
stage(ctx, x1), stg_y
4444
end
4545

4646
struct DaggerBroadcastStyle <: BroadcastStyle end
@@ -57,7 +57,7 @@ function stage(ctx::Context, node::BCast{B,T,N}) where {B,T,N}
5757
bc = Broadcast.flatten(node.bcasted)
5858
args = bc.args
5959
args1 = map(args) do x
60-
x isa ArrayOp ? cached_stage(ctx, x) : x
60+
x isa ArrayOp ? stage(ctx, x) : x
6161
end
6262
ds = map(x->x isa DArray ? domainchunks(x) : nothing, args1)
6363
sz = size(node)
@@ -84,7 +84,7 @@ function stage(ctx::Context, node::BCast{B,T,N}) where {B,T,N}
8484
end
8585
end |> Tuple
8686
dmn = DomainBlocks(ntuple(_->1, length(s)), splits)
87-
cached_stage(ctx, Distribute(dmn, part, arg)).chunks
87+
stage(ctx, Distribute(dmn, part, arg)).chunks
8888
else
8989
arg
9090
end
@@ -105,7 +105,7 @@ end
105105
mapchunk(f::Function, xs::ArrayOp...) = MapChunk(f, xs)
106106
Base.@deprecate mappart(args...) mapchunk(args...)
107107
function stage(ctx::Context, node::MapChunk)
108-
inputs = map(x->cached_stage(ctx, x), node.input)
108+
inputs = map(x->stage(ctx, x), node.input)
109109
thunks = map(map(chunks, inputs)...) do ps...
110110
Dagger.spawn(node.f, map(p->nothing=>p, ps)...)
111111
end

Diff for: src/array/setindex.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function setindex(x::ArrayOp, val, idxs...)
1212
end
1313

1414
function stage(ctx::Context, sidx::SetIndex)
15-
inp = cached_stage(ctx, sidx.input)
15+
inp = stage(ctx, sidx.input)
1616

1717
dmn = domain(inp)
1818
idxs = [if isa(sidx.idx[i], Colon)

Diff for: src/compute.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export stage, cached_stage, compute, debug_compute, cleanup
1+
export compute, debug_compute
22

33
###### Scheduler #######
44

Diff for: src/file-io.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ function save(p::Computation, name::AbstractString)
335335
end
336336

337337
function stage(ctx::Context, s::Save)
338-
x = cached_stage(ctx, s.input)
338+
x = stage(ctx, s.input)
339339
dir_path = s.name * "_data"
340340
if !isdir(dir_path)
341341
mkdir(dir_path)

0 commit comments

Comments
 (0)