Skip to content

Commit 3c168b2

Browse files
committed
DArray: Indexing improvements and fixes
Adds optional task-local cache for getindex Adds setindex! operation Fixes lookup_parts indexing fallback for linear indexing Adds allowscalar logic to prevent slow paths Adds first/last helpers
1 parent bfc4313 commit 3c168b2

File tree

10 files changed

+209
-83
lines changed

10 files changed

+209
-83
lines changed

Diff for: Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
2020
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2121
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2222
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
23+
TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
2324
TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63"
2425
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2526

@@ -34,6 +35,7 @@ Requires = "1"
3435
ScopedValues = "1.1"
3536
Statistics = "1"
3637
StatsBase = "0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
38+
TaskLocalValues = "0.1"
3739
TimespanLogging = "0.1"
3840
julia = "1.8"
3941

Diff for: docs/src/darray.md

+3-7
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,7 @@ julia> DZ = DY .* 3
134134
Dagger.DArray{Float64, 2, Blocks{2}, typeof(cat)}(100, 100)
135135
```
136136

137-
Now, `DZ` will contain the result of computing `(DX .+ DX) .* 3`. Note that
138-
`DArray` objects are immutable, and operations on them are thus functional
139-
transformations of their input `DArray`.
140-
141-
!!! note
142-
Support for mutation of `DArray`s is planned for a future release
137+
Now, `DZ` will contain the result of computing `(DX .+ DX) .* 3`.
143138

144139
```
145140
julia> Dagger.chunks(DZ)
@@ -208,14 +203,15 @@ julia> collect(DZ)
208203
```
209204

210205
A variety of other operations exist on the `DArray`, and it should generally
211-
behavior otherwise similar to any other `AbstractArray` type. If you find that
206+
behave otherwise similar to any other `AbstractArray` type. If you find that
212207
it's missing an operation that you need, please file an issue!
213208

214209
### Known Supported Operations
215210

216211
This list is not exhaustive, but documents operations which are known to work well with the `DArray`:
217212

218213
From `Base`:
214+
- `getindex`/`setindex!`
219215
- Broadcasting
220216
- `similar`/`copy`/`copyto!`
221217
- `map`/`reduce`/`mapreduce`

Diff for: src/Dagger.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ include("array/copy.jl")
6666
include("file-io.jl")
6767

6868
include("array/operators.jl")
69-
include("array/getindex.jl")
69+
include("array/indexing.jl")
7070
include("array/setindex.jl")
7171
include("array/matrix.jl")
7272
include("array/sparse_partition.jl")

Diff for: src/array/darray.jl

+11-3
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ Base.:(/)(x::DArray{T,N,B,F}, y::U) where {T<:Real,U<:Real,N,B,F} =
234234
A `view` of a `DArray` chunk returns a `DArray` of `Thunk`s.
235235
"""
236236
function Base.view(c::DArray, d)
237-
subchunks, subdomains = lookup_parts(chunks(c), domainchunks(c), d)
237+
subchunks, subdomains = lookup_parts(c, chunks(c), domainchunks(c), d)
238238
d1 = alignfirst(d)
239239
DArray(eltype(c), d1, subdomains, subchunks, c.partitioning, c.concat)
240240
end
@@ -272,7 +272,7 @@ function group_indices(cumlength, idxs::AbstractRange)
272272
end
273273

274274
_cumsum(x::AbstractArray) = length(x) == 0 ? Int[] : cumsum(x)
275-
function lookup_parts(ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}) where N
275+
function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}) where N
276276
groups = map(group_indices, subdmns.cumlength, indexes(d))
277277
sz = map(length, groups)
278278
pieces = Array{Any}(undef, sz)
@@ -284,7 +284,15 @@ function lookup_parts(ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomai
284284
end
285285
out_cumlength = map(g->_cumsum(map(x->length(x[2]), g)), groups)
286286
out_dmn = DomainBlocks(ntuple(x->1,Val(N)), out_cumlength)
287-
pieces, out_dmn
287+
return pieces, out_dmn
288+
end
289+
function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{S}) where {N,S}
290+
if S != 1
291+
throw(BoundsError(A, d.indexes))
292+
end
293+
inds = CartesianIndices(A)[d.indexes...]
294+
new_d = ntuple(i->first(inds).I[i]:last(inds).I[i], N)
295+
return lookup_parts(A, ps, subdmns, ArrayDomain(new_d))
288296
end
289297

290298
"""

Diff for: src/array/getindex.jl

-42
This file was deleted.

Diff for: src/array/indexing.jl

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
using TaskLocalValues
2+
3+
### getindex
4+
5+
struct GetIndex{T,N} <: ArrayOp{T,N}
6+
input::ArrayOp
7+
idx::Tuple
8+
end
9+
10+
GetIndex(input::ArrayOp, idx::Tuple) =
11+
GetIndex{eltype(input), ndims(input)}(input, idx)
12+
13+
function stage(ctx::Context, gidx::GetIndex)
14+
inp = stage(ctx, gidx.input)
15+
16+
dmn = domain(inp)
17+
idxs = [if isa(gidx.idx[i], Colon)
18+
indexes(dmn)[i]
19+
else
20+
gidx.idx[i]
21+
end for i in 1:length(gidx.idx)]
22+
23+
# Figure out output dimension
24+
view(inp, ArrayDomain(idxs))
25+
end
26+
27+
function size(x::GetIndex)
28+
map(a -> a[2] isa Colon ?
29+
size(x.input, a[1]) : length(a[2]),
30+
enumerate(x.idx)) |> Tuple
31+
end
32+
33+
Base.getindex(c::ArrayOp, idx::ArrayDomain) =
34+
_to_darray(GetIndex(c, indexes(idx)))
35+
Base.getindex(c::ArrayOp, idx...) =
36+
_to_darray(GetIndex(c, idx))
37+
38+
const GETINDEX_CACHE = TaskLocalValue{Dict{Tuple,Any}}(()->Dict{Tuple,Any}())
39+
const GETINDEX_CACHE_SIZE = ScopedValue{Int}(0)
40+
with_index_caching(f, size::Integer=1) = with(f, GETINDEX_CACHE_SIZE=>size)
41+
function Base.getindex(A::DArray{T,N}, idx::NTuple{N,Int}) where {T,N}
42+
# Scalar indexing check
43+
assert_allowscalar()
44+
45+
# Boundscheck
46+
checkbounds(A, idx...)
47+
48+
# Find the associated partition and offset within it
49+
part_idx, offset_idx = partition_for(A, idx)
50+
51+
# If the partition is cached, use that for lookup
52+
cache = GETINDEX_CACHE[]
53+
cache_size = GETINDEX_CACHE_SIZE[]
54+
if cache_size > 0 && haskey(cache, part_idx)
55+
return cache[part_idx][offset_idx...]
56+
end
57+
58+
# Uncached, fetch the partition
59+
part = fetch(A.chunks[part_idx...])
60+
61+
# Insert the partition into the cache
62+
if cache_size > 0
63+
if length(cache) >= cache_size
64+
# Evict a random entry
65+
key = rand(keys(cache))
66+
delete!(cache, key)
67+
end
68+
cache[part_idx] = part
69+
end
70+
71+
# Return the value
72+
return part[offset_idx...]
73+
end
74+
function partition_for(A::DArray, idx::NTuple{N,Int}) where N
75+
part_idx = zeros(Int, N)
76+
offset_idx = zeros(Int, N)
77+
for dim in 1:N
78+
part_idx_slice = @view part_idx[1:(dim-1)]
79+
trailing_idx_slice = ntuple(i->Colon(), N-dim)
80+
sds = @view A.subdomains[part_idx_slice..., :, trailing_idx_slice...]
81+
for (sd_idx, sd) in enumerate(sds)
82+
sd_range = (sd.indexes::NTuple{N,UnitRange{Int}})[dim]
83+
if sd_range.start <= idx[dim] <= sd_range.stop
84+
part_idx[dim] = sd_idx
85+
offset_idx[dim] = idx[dim] - sd_range.start + 1
86+
break
87+
end
88+
end
89+
end
90+
return (part_idx...,), (offset_idx...,)
91+
end
92+
Base.getindex(A::DArray, idx::Integer...) =
93+
getindex(A, idx)
94+
Base.getindex(A::DArray, idx::Integer) =
95+
getindex(A, Base._ind2sub(A, idx))
96+
Base.getindex(A::DArray, idx::CartesianIndex) =
97+
getindex(A, Tuple(idx))
98+
99+
### setindex!
100+
101+
function Base.setindex!(A::DArray{T,N}, value, idx::NTuple{N,Int}) where {T,N}
102+
# Scalar indexing check
103+
assert_allowscalar()
104+
105+
# Boundscheck
106+
checkbounds(A, idx...)
107+
108+
# Find the associated partition and offset within it
109+
part_idx, offset_idx = partition_for(A, idx)
110+
111+
# If the partition is cached, evict it
112+
cache = GETINDEX_CACHE[]
113+
if haskey(cache, part_idx)
114+
delete!(cache, part_idx)
115+
end
116+
117+
# Set the value
118+
part = A.chunks[part_idx...]
119+
space = memory_space(part)
120+
scope = Dagger.scope(worker=root_worker_id(space))
121+
return fetch(Dagger.@spawn scope=scope setindex!(part, value, offset_idx...))
122+
end
123+
Base.setindex!(A::DArray, value, idx::Integer...) =
124+
setindex!(A, value, idx)
125+
Base.setindex!(A::DArray, value, idx::Integer) =
126+
setindex!(A, value, Base._ind2sub(A, idx))
127+
Base.setindex!(A::DArray, value, idx::CartesianIndex) =
128+
setindex!(A, value, Tuple(idx))
129+
130+
### Allow/disallow scalar indexing
131+
132+
const ALLOWSCALAR_TASK = TaskLocalValue{Bool}(()->true)
133+
const ALLOWSCALAR_SCOPE = ScopedValue{Bool}(false)
134+
isallowscalar() = ALLOWSCALAR_TASK[] || ALLOWSCALAR_SCOPE[]
135+
function assert_allowscalar()
136+
if !isallowscalar()
137+
throw(ArgumentError("Scalar indexing is disallowed\nSee `allowscalar` and `allowscalar!` for ways to disable this check, if necessary"))
138+
end
139+
end
140+
"Allow/disallow scalar indexing for the current task."
141+
function allowscalar!(allow::Bool=true)
142+
ALLOWSCALAR_TASK[] = allow
143+
end
144+
"Allow/disallow scalar indexing for the duration of executing `f`."
145+
function allowscalar(f, allow::Bool=true)
146+
old = ALLOWSCALAR_TASK[]
147+
allowscalar!(allow)
148+
try
149+
return with(f, ALLOWSCALAR_SCOPE=>allow)
150+
finally
151+
allowscalar!(old)
152+
end
153+
end

Diff for: src/array/operators.jl

+5
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,8 @@ function stage(ctx::Context, node::MapChunk)
113113
# TODO: Concrete type
114114
DArray(Any, domain(inputs[1]), domainchunks(inputs[1]), thunks)
115115
end
116+
117+
# Basic indexing helpers
118+
119+
Base.first(A::DArray) = A[begin]
120+
Base.last(A::DArray) = A[end]

Diff for: src/lib/domain-blocks.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function getindex(x::DomainBlocks{N}, idx::Int) where N
1717
if N == 1
1818
_getindex(x, (idx,))
1919
else
20-
_getindex(x, ind2sub(x, idx))
20+
_getindex(x, Base._ind2sub(x, idx))
2121
end
2222
end
2323

Diff for: src/memory-spaces.jl

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ abstract type MemorySpace end
33
struct CPURAMMemorySpace <: MemorySpace
44
owner::Int
55
end
6+
root_worker_id(space::CPURAMMemorySpace) = space.owner
67

78
memory_space(x) = CPURAMMemorySpace(myid())
89
function memory_space(x::Chunk)

0 commit comments

Comments
 (0)