Skip to content

Commit 3ee4091

Browse files
authored
Merge pull request #535 from JuliaParallel/jps/datadeps-gpu
DArray: Make allocations configurable via Processor
2 parents ee6b0ce + c822206 commit 3ee4091

File tree

3 files changed

+29 -20 lines changed

Diff for: src/array/alloc.jl

+21 -8
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,8 @@ export partition
44

55
mutable struct AllocateArray{T,N} <: ArrayOp{T,N}
66
eltype::Type{T}
7-
f::Function
7+
f
8+
want_index::Bool
89
domain::ArrayDomain{N}
910
domainchunks
1011
partitioning::AbstractBlocks
@@ -23,17 +24,29 @@ function partition(p::AbstractBlocks, dom::ArrayDomain)
2324
map(_cumlength, map(length, indexes(dom)), p.blocksize))
2425
end
2526

27+
function allocate_array(f, T, idx, sz)
28+
new_f = allocate_array_func(thunk_processor(), f)
29+
return new_f(idx, T, sz)
30+
end
31+
function allocate_array(f, T, sz)
32+
new_f = allocate_array_func(thunk_processor(), f)
33+
return new_f(T, sz)
34+
end
35+
allocate_array_func(::Processor, f) = f
2636
function stage(ctx, a::AllocateArray)
27-
alloc(idx, sz) = a.f(idx, a.eltype, sz)
28-
thunks = [Dagger.@spawn alloc(i, size(x)) for (i, x) in enumerate(a.domainchunks)]
37+
if a.want_index
38+
thunks = [Dagger.@spawn allocate_array(a.f, a.eltype, i, size(x)) for (i, x) in enumerate(a.domainchunks)]
39+
else
40+
thunks = [Dagger.@spawn allocate_array(a.f, a.eltype, size(x)) for (i, x) in enumerate(a.domainchunks)]
41+
end
2942
return DArray(a.eltype, a.domain, a.domainchunks, thunks, a.partitioning)
3043
end
3144

3245
const BlocksOrAuto = Union{Blocks{N} where N, AutoBlocks}
3346

3447
function Base.rand(p::Blocks, eltype::Type, dims::Dims)
3548
d = ArrayDomain(map(x->1:x, dims))
36-
a = AllocateArray(eltype, (_, x...) -> rand(x...), d, partition(p, d), p)
49+
a = AllocateArray(eltype, rand, false, d, partition(p, d), p)
3750
return _to_darray(a)
3851
end
3952
Base.rand(p::BlocksOrAuto, T::Type, dims::Integer...) = rand(p, T, dims)
@@ -45,7 +58,7 @@ Base.rand(::AutoBlocks, eltype::Type, dims::Dims) =
4558

4659
function Base.randn(p::Blocks, eltype::Type, dims::Dims)
4760
d = ArrayDomain(map(x->1:x, dims))
48-
a = AllocateArray(eltype, (_, x...) -> randn(x...), d, partition(p, d), p)
61+
a = AllocateArray(eltype, randn, false, d, partition(p, d), p)
4962
return _to_darray(a)
5063
end
5164
Base.randn(p::BlocksOrAuto, T::Type, dims::Integer...) = randn(p, T, dims)
@@ -57,7 +70,7 @@ Base.randn(::AutoBlocks, eltype::Type, dims::Dims) =
5770

5871
function sprand(p::Blocks, eltype::Type, dims::Dims, sparsity::AbstractFloat)
5972
d = ArrayDomain(map(x->1:x, dims))
60-
a = AllocateArray(eltype, (_, T, _dims) -> sprand(T, _dims..., sparsity), d, partition(p, d), p)
73+
a = AllocateArray(eltype, (T, _dims) -> sprand(T, _dims..., sparsity), false, d, partition(p, d), p)
6174
return _to_darray(a)
6275
end
6376
sprand(p::BlocksOrAuto, T::Type, dims_and_sparsity::Real...) =
@@ -73,7 +86,7 @@ sprand(::AutoBlocks, eltype::Type, dims::Dims, sparsity::AbstractFloat) =
7386

7487
function Base.ones(p::Blocks, eltype::Type, dims::Dims)
7588
d = ArrayDomain(map(x->1:x, dims))
76-
a = AllocateArray(eltype, (_, x...) -> ones(x...), d, partition(p, d), p)
89+
a = AllocateArray(eltype, ones, false, d, partition(p, d), p)
7790
return _to_darray(a)
7891
end
7992
Base.ones(p::BlocksOrAuto, T::Type, dims::Integer...) = ones(p, T, dims)
@@ -85,7 +98,7 @@ Base.ones(::AutoBlocks, eltype::Type, dims::Dims) =
8598

8699
function Base.zeros(p::Blocks, eltype::Type, dims::Dims)
87100
d = ArrayDomain(map(x->1:x, dims))
88-
a = AllocateArray(eltype, (_, x...) -> zeros(x...), d, partition(p, d), p)
101+
a = AllocateArray(eltype, zeros, false, d, partition(p, d), p)
89102
return _to_darray(a)
90103
end
91104
Base.zeros(p::BlocksOrAuto, T::Type, dims::Integer...) = zeros(p, T, dims)

Diff for: src/array/darray.jl

+3 -7
Original file line number | Diff line number | Diff line change
@@ -306,16 +306,12 @@ function Base.isequal(x::ArrayOp, y::ArrayOp)
306306
x === y
307307
end
308308

309-
function Base.similar(x::DArray{T,N}) where {T,N}
310-
alloc(idx, sz) = Array{T,N}(undef, sz)
311-
thunks = [Dagger.@spawn alloc(i, size(x)) for (i, x) in enumerate(x.subdomains)]
312-
return DArray(T, x.domain, x.subdomains, thunks, x.partitioning, x.concat)
313-
end
314-
309+
struct AllocateUndef{S} end
310+
(::AllocateUndef{S})(T, dims::Dims{N}) where {S,N} = Array{S,N}(undef, dims)
315311
function Base.similar(A::DArray{T,N} where T, ::Type{S}, dims::Dims{N}) where {S,N}
316312
d = ArrayDomain(map(x->1:x, dims))
317313
p = A.partitioning
318-
a = AllocateArray(S, (_, _, x...) -> Array{S,N}(undef, x...), d, partition(p, d), p)
314+
a = AllocateArray(S, AllocateUndef{S}(), false, d, partition(p, d), p)
319315
return _to_darray(a)
320316
end
321317

Diff for: src/array/mul.jl

+5 -5
Original file line number | Diff line number | Diff line change
@@ -109,8 +109,8 @@ function gemm_dagger!(
109109
Bmt, Bnt = size(Bc)
110110
Cmt, Cnt = size(Cc)
111111

112-
alpha = _add.alpha
113-
beta = _add.beta
112+
alpha = T(_add.alpha)
113+
beta = T(_add.beta)
114114

115115
if Ant != Bmt
116116
throw(DimensionMismatch(lazy"A has number of blocks ($Amt,$Ant) but B has number of blocks ($Bmt,$Bnt)"))
@@ -212,8 +212,8 @@ function syrk_dagger!(
212212
Amt, Ant = size(Ac)
213213
Cmt, Cnt = size(Cc)
214214

215-
alpha = _add.alpha
216-
beta = _add.beta
215+
alpha = T(_add.alpha)
216+
beta = T(_add.beta)
217217

218218
uplo = 'U'
219219
if Ant != Cmt
@@ -233,7 +233,7 @@ function syrk_dagger!(
233233
Dagger.@spawn BLAS.herk!(
234234
uplo,
235235
trans,
236-
alpha,
236+
real(alpha),
237237
In(Ac[n, k]),
238238
mzone,
239239
InOut(Cc[n, n]),

0 commit comments

Comments
 (0)