Skip to content

Commit f8e2197

Browse files
authored
Merge pull request #402 from JuliaParallel/jps/diffeq-support
DArray: Improve DiffEq support
2 parents 51fdb5e + 7e09407 commit f8e2197

File tree

4 files changed

+56
-2
lines changed

4 files changed

+56
-2
lines changed

Diff for: src/array/alloc.jl

+8
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ Base.zeros(p::Blocks, t::Type, dims::Integer...) = zeros(p, t, dims)
6868
Base.zeros(p::Blocks, dims::Integer...) = zeros(p, Float64, dims)
6969
Base.zeros(p::Blocks, dims::Tuple) = zeros(p, Float64, dims)
7070

71+
function Base.zero(x::DArray{T,N}) where {T,N}
72+
dims = ntuple(i->x.domain.indexes[i].stop, N)
73+
sd = first(x.subdomains)
74+
part_size = ntuple(i->sd.indexes[i].stop, N)
75+
a = zeros(Blocks(part_size...), T, dims)
76+
return cached_stage(Context(global_context()), a)
77+
end
78+
7179
function sprand(p::Blocks, m::Integer, n::Integer, sparsity::Real)
7280
s = rand(UInt)
7381
f = function (idx, t,sz)

Diff for: src/array/darray.jl

+13
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,19 @@ function Base.isequal(x::ArrayOp, y::ArrayOp)
163163
x === y
164164
end
165165

166+
function Base.similar(x::DArray{T,N,F}) where {T,N,F}
167+
alloc(idx, sz) = Array{T,N}(undef, sz)
168+
thunks = [Dagger.@spawn alloc(i, size(x)) for (i, x) in enumerate(x.subdomains)]
169+
return DArray{T,N,F}(x.domain, x.subdomains, thunks, x.concat)
170+
end
171+
172+
Base.copy(x::DArray{T,N,F}) where {T,N,F} =
173+
cached_stage(Context(global_context()), map(identity, x))::DArray{T,N,F}
174+
175+
# Because OrdinaryDiffEq uses `Base.promote_op(/, ::DArray, ::Real)`
176+
Base.:(/)(x::DArray{T,N,F}, y::U) where {T<:Real,U<:Real,N,F} =
177+
(x ./ y)::DArray{Base.promote_op(/, T, U),N,F}
178+
166179
"""
167180
view(c::DArray, d)
168181

Diff for: src/array/operators.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import Base: exp, expm1, log, log10, log1p, sqrt, cbrt, exponent,
32
significand, sin, sinpi, cos, cospi, tan, sec, cot, csc,
43
sinh, cosh, tanh, coth, sech, csch,
@@ -51,7 +50,7 @@ BroadcastStyle(::DaggerBroadcastStyle, ::BroadcastStyle) = DaggerBroadcastStyle(
5150
BroadcastStyle(::BroadcastStyle, ::DaggerBroadcastStyle) = DaggerBroadcastStyle()
5251

5352
function Base.copy(b::Broadcast.Broadcasted{<:DaggerBroadcastStyle})
54-
BCast(b)
53+
cached_stage(Context(global_context()), BCast(b))::DArray
5554
end
5655

5756
function stage(ctx::Context, node::BCast)

Diff for: test/array.jl

+34
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,30 @@ end
4747
@test collect(X1) .+ 1 == collect(X2)
4848
end
4949

50+
@testset "copy/similar" begin
51+
X1 = fetch(ones(Blocks(10, 10), 100, 100))
52+
X2 = copy(X1)
53+
X3 = similar(X1)
54+
@test typeof(X1) === typeof(X2) === typeof(X3)
55+
@test collect(X1) == collect(X2)
56+
@test collect(X1) != collect(X3)
57+
end
58+
59+
@testset "DiffEq support" begin
60+
X = fetch(ones(Blocks(10), 100))
61+
X0 = zero(X)
62+
@test typeof(X) === typeof(X0)
63+
@test all(collect(X0) .== 0)
64+
@testset for T in (Int8, Int, Float32, Float64)
65+
DT = DArray{Base.promote_op(/, Float64, T), 1, typeof(cat)}
66+
@test Base.promote_op(/, typeof(X), T) === DT
67+
y = T(2)
68+
Xd = X / y
69+
@test typeof(Xd) === DT
70+
@test collect(Xd) == collect(X) ./ y
71+
end
72+
end
73+
5074
@testset "sum" begin
5175
X = ones(Blocks(10, 10), 100, 100)
5276
@test sum(X) == 10000
@@ -70,6 +94,16 @@ end
7094
@test mean(Y) == 0
7195
end
7296

97+
@testset "broadcast" begin
98+
X1 = fetch(rand(Blocks(10), 100))
99+
X2 = X1 .* 3.4
100+
@test typeof(X1) === typeof(X2)
101+
@test collect(X1) .* 3.4 == collect(X2)
102+
X3 = X1 .+ X1
103+
@test typeof(X1) === typeof(X3)
104+
@test collect(X1) .* 2 == collect(X3)
105+
end
106+
73107
@testset "distributing an array" begin
74108
function test_dist(X)
75109
X1 = Distribute(Blocks(10, 20), X)

0 commit comments

Comments
 (0)