Skip to content

Commit 33dca52

Browse files
Check that eltype from each shard is consistent (#170)
Co-authored-by: Rami <[email protected]> Co-authored-by: Valentin Churavy <[email protected]>
1 parent 901b14e commit 33dca52

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

src/darray.jl

+13-5
Original file line numberDiff line numberDiff line change
@@ -78,23 +78,31 @@ Base.hash(d::DArray, h::UInt) = Base.hash(d.id, h)
7878
## core constructors ##
7979

8080
function DArray(id, init, dims, pids, idxs, cuts)
81-
r=Channel(1)
81+
localtypes = Vector{DataType}(undef,length(pids))
82+
8283
@sync begin
8384
for i = 1:length(pids)
8485
@async begin
8586
local typA
8687
if isa(init, Function)
87-
typA=remotecall_fetch(construct_localparts, pids[i], init, id, dims, pids, idxs, cuts)
88+
typA = remotecall_fetch(construct_localparts, pids[i], init, id, dims, pids, idxs, cuts)
8889
else
8990
# constructing from an array of remote refs.
90-
typA=remotecall_fetch(construct_localparts, pids[i], init[i], id, dims, pids, idxs, cuts)
91+
typA = remotecall_fetch(construct_localparts, pids[i], init[i], id, dims, pids, idxs, cuts)
9192
end
92-
!isready(r) && put!(r, typA)
93+
localtypes[i] = typA
9394
end
9495
end
9596
end
9697

97-
A = take!(r)
98+
if length(unique(localtypes)) != 1
99+
@sync for p in pids
100+
@async remotecall_fetch(release_localpart, p, id)
101+
end
102+
throw(ErrorException("Constructed localparts have different `eltype`: $(localtypes)"))
103+
end
104+
A = first(localtypes)
105+
98106
if myid() in pids
99107
d = registry[id]
100108
d = isa(d, WeakRef) ? d.value : d

test/darray.jl

+11-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,17 @@ using Random
6666
@test DistributedArrays.defaultdist(50,4) == [1,14,27,39,51]
6767
end
6868

69-
69+
@testset "Inhomogenous typeof(localpart)" begin
70+
block = 10
71+
Y = nworkers() * block
72+
X = nworkers() * block
73+
74+
@assert nworkers() > 1
75+
@test_throws ErrorException DArray((X, Y)) do I
76+
eltype = first(CartesianIndices(I)) == CartesianIndex(1, 1) ? Int64 : Float64
77+
zeros(eltype, map(length, I))
78+
end
79+
end
7080
end
7181

7282
check_leaks()

0 commit comments

Comments
 (0)