Skip to content

Commit 081200c

Browse files
authored
Merge pull request #469 from JuliaParallel/jps/darray-datadeps
DArray: Add Cholesky and view implementations
2 parents d0cf5cc + b44dce4 commit 081200c

File tree

7 files changed

+221
-0
lines changed

7 files changed

+221
-0
lines changed

Diff for: src/Dagger.jl

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ include("array/setindex.jl")
6060
include("array/matrix.jl")
6161
include("array/sparse_partition.jl")
6262
include("array/sort.jl")
63+
include("array/linalg.jl")
64+
include("array/cholesky.jl")
6365

6466
# Other
6567
include("ui/graph-core.jl")

Diff for: src/array/alloc.jl

+9
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ function Base.zero(x::DArray{T,N}) where {T,N}
8181
return _to_darray(a)
8282
end
8383

84+
function Base.view(A::AbstractArray{T,N}, p::Blocks{N}) where {T,N}
85+
d = ArrayDomain(Base.index_shape(A))
86+
dc = partition(p, d)
87+
# N.B. We use `tochunk` because we only want to take the view locally, and
88+
# taking views should be very fast
89+
chunks = [tochunk(view(A, x.indexes...)) for x in dc]
90+
return DArray(T, d, dc, chunks, p)
91+
end
92+
8493
function sprand(p::Blocks, m::Integer, n::Integer, sparsity::Real)
8594
s = rand(UInt)
8695
f = function (idx, t,sz)

Diff for: src/array/cholesky.jl

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
LinearAlgebra.cholcopy(A::DArray{T,2}) where T = copy(A)
2+
function potrf_checked!(uplo, A, info_arr)
3+
_A, info = LAPACK.potrf!(uplo, A)
4+
if info > 0
5+
info_arr[1] = info
6+
throw(PosDefException(info))
7+
end
8+
return _A, info
9+
end
10+
function LinearAlgebra._chol!(A::DArray{T,2}, ::Type{UpperTriangular}) where T
11+
LinearAlgebra.checksquare(A)
12+
13+
zone = one(T)
14+
mzone = -one(T)
15+
rzone = one(real(T))
16+
rmzone = -one(real(T))
17+
uplo = 'U'
18+
Ac = A.chunks
19+
mt, nt = size(Ac)
20+
iscomplex = T <: Complex
21+
trans = iscomplex ? 'C' : 'T'
22+
23+
info = [convert(LinearAlgebra.BlasInt, 0)]
24+
try
25+
Dagger.spawn_datadeps() do
26+
for k in range(1, mt)
27+
Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info))
28+
for n in range(k+1, nt)
29+
Dagger.@spawn BLAS.trsm!('L', uplo, trans, 'N', zone, In(Ac[k, k]), InOut(Ac[k, n]))
30+
end
31+
for m in range(k+1, mt)
32+
if iscomplex
33+
Dagger.@spawn BLAS.herk!(uplo, 'C', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
34+
else
35+
Dagger.@spawn BLAS.syrk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
36+
end
37+
for n in range(m+1, nt)
38+
Dagger.@spawn BLAS.gemm!(trans, 'N', mzone, In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n]))
39+
end
40+
end
41+
end
42+
end
43+
catch err
44+
err isa ThunkFailedException || rethrow()
45+
err = Dagger.Sch.unwrap_nested_exception(err.ex)
46+
err isa PosDefException || rethrow()
47+
end
48+
49+
return UpperTriangular(A), info[1]
50+
end
51+
function LinearAlgebra._chol!(A::DArray{T,2}, ::Type{LowerTriangular}) where T
52+
LinearAlgebra.checksquare(A)
53+
54+
zone = one(T)
55+
mzone = -one(T)
56+
rzone = one(real(T))
57+
rmzone = -one(real(T))
58+
uplo = 'L'
59+
Ac = A.chunks
60+
mt, nt = size(Ac)
61+
iscomplex = T <: Complex
62+
trans = iscomplex ? 'C' : 'T'
63+
64+
info = [convert(LinearAlgebra.BlasInt, 0)]
65+
try
66+
Dagger.spawn_datadeps() do
67+
for k in range(1, mt)
68+
Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info))
69+
for m in range(k+1, mt)
70+
Dagger.@spawn BLAS.trsm!('R', uplo, trans, 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]))
71+
end
72+
for n in range(k+1, nt)
73+
if iscomplex
74+
Dagger.@spawn BLAS.herk!(uplo, 'N', rmzone, In(Ac[n, k]), rzone, InOut(Ac[n, n]))
75+
else
76+
Dagger.@spawn BLAS.syrk!(uplo, 'N', rmzone, In(Ac[n, k]), rzone, InOut(Ac[n, n]))
77+
end
78+
for m in range(n+1, mt)
79+
Dagger.@spawn BLAS.gemm!('N', trans, mzone, In(Ac[m, k]), In(Ac[n, k]), zone, InOut(Ac[m, n]))
80+
end
81+
end
82+
end
83+
end
84+
catch err
85+
err isa ThunkFailedException || rethrow()
86+
err = Dagger.Sch.unwrap_nested_exception(err.ex)
87+
err isa PosDefException || rethrow()
88+
end
89+
90+
return LowerTriangular(A), info[1]
91+
end

Diff for: src/array/linalg.jl

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
function LinearAlgebra.norm2(A::DArray{T,2}) where T
2+
Ac = A.chunks
3+
norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac]
4+
return sqrt(sum(map(fetch, norms)))
5+
end
6+
function LinearAlgebra.norm2(A::UpperTriangular{T,<:DArray{T,2}}) where T
7+
Ac = parent(A).chunks
8+
Ac_upper = []
9+
for i in 1:size(Ac, 1)
10+
append!(Ac_upper, Ac[i, (i+1):end])
11+
end
12+
upper_norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac_upper]
13+
Ac_diag = [Dagger.spawn(UpperTriangular, Ac[i,i]) for i in 1:size(Ac, 1)]
14+
diag_norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac_diag]
15+
return sqrt(sum(map(fetch, upper_norms)) + sum(map(fetch, diag_norms)))
16+
end
17+
function LinearAlgebra.norm2(A::LowerTriangular{T,<:DArray{T,2}}) where T
18+
Ac = parent(A).chunks
19+
Ac_lower = []
20+
for i in 1:size(Ac, 1)
21+
append!(Ac_lower, Ac[(i+1):end, i])
22+
end
23+
lower_norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac_lower]
24+
Ac_diag = [Dagger.spawn(LowerTriangular, Ac[i,i]) for i in 1:size(Ac, 1)]
25+
diag_norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac_diag]
26+
return sqrt(sum(map(fetch, lower_norms)) + sum(map(fetch, diag_norms)))
27+
end
28+
29+
is_cross_symmetric(A1, A2) = A1 == A2'
30+
function LinearAlgebra.issymmetric(A::DArray{T,2}) where T
31+
Ac = A.chunks
32+
if size(Ac, 1) != size(Ac, 2)
33+
return false
34+
end
35+
36+
to_check = [Dagger.@spawn issymmetric(Ac[i, i]) for i in 1:size(Ac, 1)]
37+
for i in 2:(size(Ac, 1)-1)
38+
j_pre_diag = i - 1
39+
for j in 1:j_pre_diag
40+
push!(to_check, Dagger.@spawn is_cross_symmetric(Ac[i, j], Ac[j, i]))
41+
end
42+
end
43+
44+
return all(fetch, to_check)
45+
end
46+
function LinearAlgebra.ishermitian(A::DArray{T,2}) where T
47+
Ac = A.chunks
48+
if size(Ac, 1) != size(Ac, 2)
49+
return false
50+
end
51+
52+
to_check = [Dagger.@spawn ishermitian(Ac[i, i]) for i in 1:size(Ac, 1)]
53+
for i in 2:(size(Ac, 1)-1)
54+
j_pre_diag = i - 1
55+
for j in 1:j_pre_diag
56+
push!(to_check, Dagger.@spawn is_cross_symmetric(Ac[i, j], Ac[j, i]))
57+
end
58+
end
59+
60+
return all(fetch, to_check)
61+
end

Diff for: test/array.jl

+10
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ end
3838
@test r[1:10] != r[11:20]
3939
end
4040

41+
@testset "view" begin
42+
A = rand(64, 64)
43+
DA = view(A, Blocks(8, 8))
44+
@test collect(DA) == A
45+
@test size(DA) == (64, 64)
46+
A_v = fetch(first(DA.chunks))
47+
@test A_v isa SubArray
48+
@test A_v == A[1:8, 1:8]
49+
end
50+
4151
@testset "map" begin
4252
X1 = ones(Blocks(10, 10), 100, 100)
4353
X2 = map(x->x+1, X1)

Diff for: test/linalg.jl

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
@testset "Linear Algebra" begin
2+
@testset "Cholesky: $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
3+
D = rand(Blocks(4, 4), T, 32, 32)
4+
if !(T <: Complex)
5+
@test !issymmetric(D)
6+
end
7+
@test !ishermitian(D)
8+
9+
A = rand(T, 128, 128)
10+
A = A * A'
11+
A[diagind(A)] .+= size(A, 1)
12+
DA = view(A, Blocks(32, 32))
13+
if !(T <: Complex)
14+
@test issymmetric(DA)
15+
end
16+
@test ishermitian(DA)
17+
18+
# Out-of-place
19+
chol_A = cholesky(A)
20+
chol_DA = cholesky(DA)
21+
@test chol_DA isa Cholesky
22+
@test chol_A.L chol_DA.L
23+
@test chol_A.U chol_DA.U
24+
25+
# In-place
26+
A_copy = copy(A)
27+
chol_A = cholesky!(A_copy)
28+
chol_DA = cholesky!(DA)
29+
@test chol_DA isa Cholesky
30+
@test chol_A.L chol_DA.L
31+
@test chol_A.U chol_DA.U
32+
# Check that changes propagated to A
33+
@test UpperTriangular(collect(DA)) UpperTriangular(collect(A))
34+
35+
# Non-PosDef matrix
36+
A = rand(T, 128, 128)
37+
A = A * A'
38+
A[diagind(A)] .+= size(A, 1)
39+
A[1, 1] = -100
40+
DA = view(A, Blocks(32, 32))
41+
if !(T <: Complex)
42+
@test issymmetric(DA)
43+
end
44+
@test ishermitian(DA)
45+
@test_throws_unwrap PosDefException cholesky(DA).U
46+
end
47+
end

Diff for: test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ include("task-queues.jl")
3838
include("datadeps.jl")
3939
include("domain.jl")
4040
include("array.jl")
41+
include("linalg.jl")
4142
include("cache.jl")
4243
include("diskcaching.jl")
4344
include("file-io.jl")

0 commit comments

Comments
 (0)