Skip to content

Commit 52abd35

Browse files
authored
Merge pull request #528 from JuliaParallel/jps/darray-0dim
DArray: Improve 0-dim support
2 parents ff89478 + 0388dd6 commit 52abd35

File tree

2 files changed

+55
-15
lines changed

2 files changed

+55
-15
lines changed

Diff for: src/array/darray.jl

+23-5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ ArrayDomain(xs::NTuple{N,Base.OneTo}) where N =
2020
ArrayDomain{N,NTuple{N,UnitRange{Int}}}(ntuple(i->UnitRange(xs[i]), N))
2121
ArrayDomain(xs::NTuple{N,Int}) where N =
2222
ArrayDomain{N,NTuple{N,UnitRange{Int}}}(ntuple(i->xs[i]:xs[i], N))
23+
ArrayDomain(::Tuple{}) = ArrayDomain{0,Tuple{}}(())
2324
ArrayDomain(xs...) = ArrayDomain((xs...,))
2425
ArrayDomain(xs::Array) = ArrayDomain((xs...,))
2526

@@ -31,6 +32,7 @@ chunks(a::ArrayDomain{N}) where {N} = DomainBlocks(
3132

3233
(==)(a::ArrayDomain, b::ArrayDomain) = indexes(a) == indexes(b)
3334
Base.getindex(arr::AbstractArray, d::ArrayDomain) = arr[indexes(d)...]
35+
Base.getindex(arr::AbstractArray{T,0} where T, d::ArrayDomain{0}) = arr
3436

3537
function intersect(a::ArrayDomain, b::ArrayDomain)
3638
if a === b
@@ -145,7 +147,6 @@ const WrappedDVector{T} = WrappedDArray{T,1}
145147
const DMatrix{T} = DArray{T,2}
146148
const DVector{T} = DArray{T,1}
147149

148-
149150
# mainly for backwards-compatibility
150151
DArray{T, N}(domain, subdomains, chunks, partitioning, concat=cat) where {T,N} =
151152
DArray(T, domain, subdomains, chunks, partitioning, concat)
@@ -178,6 +179,10 @@ function Base.collect(d::DArray; tree=false)
178179
return Array{eltype(d)}(undef, size(d)...)
179180
end
180181

182+
if ndims(d) == 0
183+
return fetch(a.chunks[1])
184+
end
185+
181186
dimcatfuncs = [(x...) -> d.concat(x..., dims=i) for i in 1:ndims(d)]
182187
if tree
183188
collect(fetch(treereduce_nd(map(x -> ((args...,) -> Dagger.@spawn x(args...)) , dimcatfuncs), a.chunks)))
@@ -214,6 +219,7 @@ else
214219
Base.alignment(io::IO, x::ColorElement) =
215220
Base.alignment(io, something(x.value, "..."))
216221
#end
222+
Base.show(io::IO, x::ColorElement) = show(io, MIME("text/plain"), x)
217223
struct ColorArray{T,N} <: DenseArray{T,N}
218224
A::DArray{T,N}
219225
color_map::Vector{Symbol}
@@ -261,9 +267,21 @@ function Base.getindex(A::ColorArray{T,N}, idxs::Dims{S}) where {T,N,S}
261267
end
262268
end
263269
function Base.show(io::IO, ::MIME"text/plain", A::DArray{T,N}) where {T,N}
264-
write(io, string(DArray{T,N}))
265-
write(io, string(size(A)))
266-
write(io, " with $(join(size(A.chunks), 'x')) partitions of size $(join(A.partitioning.blocksize, 'x')):")
270+
if N == 1
271+
write(io, "$(length(A))-element ")
272+
write(io, string(DVector{T}))
273+
elseif N == 2
274+
write(io, string(DMatrix{T}))
275+
elseif N == 0
276+
write(io, "0-dimensional ")
277+
write(io, "DArray{$T, $N}")
278+
else
279+
write(io, "$(join(size(A), 'x')) ")
280+
write(io, "DArray{$T, $N}")
281+
end
282+
nparts = N > 0 ? size(A.chunks) : 1
283+
partsize = N > 0 ? A.partitioning.blocksize : 1
284+
write(io, " with $(join(nparts, 'x')) partitions of size $(join(partsize, 'x')):")
267285
pct_complete = 100 * (sum(c->c isa Chunk ? true : isready(c), A.chunks) / length(A.chunks))
268286
if pct_complete < 100
269287
println(io)
@@ -472,7 +490,7 @@ struct AutoBlocks end
472490
function auto_blocks(dims::Dims{N}) where N
473491
# TODO: Allow other partitioning schemes
474492
np = num_processors()
475-
p = cld(dims[end], np)
493+
p = N > 0 ? cld(dims[end], np) : 1
476494
return Blocks(ntuple(i->i == N ? p : dims[i], N))
477495
end
478496
auto_blocks(A::AbstractArray{T,N}) where {T,N} = auto_blocks(size(A))

Diff for: test/array/allocation.jl

+32-10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
@testset "DVector/DMatrix/DArray constructor" begin
22
for T in [Float32, Float64, Int32, Int64]
3+
F = fill(one(T))
34
V = rand(T, 64)
45
M = rand(T, 64, 64)
56
A = rand(T, 64, 64, 64)
67

8+
# DArray ctor (empty)
9+
DF = DArray(F, Blocks(()))
10+
@test DF isa DArray{T,0}
11+
@test collect(DF) == F
12+
@test size(DF) == size(F)
13+
714
# DVector ctor
815
DV = DVector(V, Blocks(8))
916
@test DV isa DVector{T}
@@ -26,7 +33,8 @@ end
2633

2734
@testset "random" begin
2835
for T in [Float32, Float64, Int32, Int64]
29-
for dims in [(100,),
36+
for dims in [(),
37+
(100,),
3038
(100, 100),
3139
(100, 100, 100)]
3240
dist = Blocks(ntuple(i->10, length(dims))...)
@@ -53,10 +61,9 @@ end
5361
@test AXn isa Array{T,length(dims)}
5462
@test AXn == collect(Xn)
5563
@test AXn != collect(randn(dist, T, dims...))
56-
@test !all(AXn .> 0)
5764
end
5865

59-
if length(dims) <= 2
66+
if 1 <= length(dims) <= 2
6067
# sprand
6168
Xsp = sprand(dist, T, dims..., 0.1)
6269
@test Xsp isa DArray{T,length(dims)}
@@ -76,7 +83,8 @@ end
7683
@testset "ones/zeros" begin
7784
for T in [Float32, Float64, Int32, Int64]
7885
for (fn, value) in [(ones, one(T)), (zeros, zero(T))]
79-
for dims in [(100,),
86+
for dims in [(),
87+
(100,),
8088
(100, 100),
8189
(100, 100, 100)]
8290
dist = Blocks(ntuple(i->10, length(dims))...)
@@ -119,11 +127,16 @@ end
119127
for i in 1:(length(dims)-1)
120128
@test part_size[i] == 100
121129
end
122-
@test part_size[end] == cld(100, np)
130+
if length(dims) > 0
131+
@test part_size[end] == cld(100, np)
132+
else
133+
@test part_size == ()
134+
end
123135
@test size(DA) == ntuple(i->100, length(dims))
124136
end
125137

126-
for dims in [(100,),
138+
for dims in [(),
139+
(100,),
127140
(100, 100),
128141
(100, 100, 100)]
129142
fn = if length(dims) == 1
@@ -133,15 +146,23 @@ end
133146
else
134147
DArray
135148
end
136-
DA = fn(rand(dims...), AutoBlocks())
149+
if length(dims) > 0
150+
DA = fn(rand(dims...), AutoBlocks())
151+
else
152+
DA = fn(fill(rand()), AutoBlocks())
153+
end
137154
test_auto_blocks(DA, dims)
138155

139-
DA = distribute(rand(dims...), AutoBlocks())
156+
if length(dims) > 0
157+
DA = distribute(rand(dims...), AutoBlocks())
158+
else
159+
DA = distribute(fill(rand()), AutoBlocks())
160+
end
140161
test_auto_blocks(DA, dims)
141162

142163
for fn in [rand, randn, sprand, ones, zeros]
143164
if fn === sprand
144-
if length(dims) > 2
165+
if length(dims) > 2 || length(dims) == 0
145166
continue
146167
end
147168
DA = fn(AutoBlocks(), dims..., 0.1)
@@ -155,7 +176,8 @@ end
155176

156177
@testset "Constructor variants" begin
157178
for fn in [ones, zeros, rand, randn, sprand]
158-
for dims in [(100,),
179+
for dims in [(),
180+
(100,),
159181
(100, 100),
160182
(100, 100, 100)]
161183
for dist in [Blocks(ntuple(i->10, length(dims))...),

0 commit comments

Comments
 (0)