Skip to content

Commit 7b1ccab

Browse files
jpsamarooRabab53
andcommitted
DArray: Add cholesky implementation
Co-authored-by: rabab53 <[email protected]>
1 parent 75a931e commit 7b1ccab

File tree

4 files changed

+118
-0
lines changed

4 files changed

+118
-0
lines changed

Diff for: src/Dagger.jl

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ include("array/setindex.jl")
6060
include("array/matrix.jl")
6161
include("array/sparse_partition.jl")
6262
include("array/sort.jl")
63+
include("array/cholesky.jl")
6364

6465
# Other
6566
include("ui/graph-core.jl")

Diff for: src/array/cholesky.jl

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
is_cross_symmetric(A1, A2) = A1 == A2'
2+
function LinearAlgebra.issymmetric(A::DArray{T,2}) where T
3+
Ac = A.chunks
4+
if size(Ac, 1) != size(Ac, 2)
5+
return false
6+
end
7+
8+
to_check = [Dagger.@spawn issymmetric(Ac[i, i]) for i in 1:size(Ac, 1)]
9+
for i in 2:(size(Ac, 1)-1)
10+
j_pre_diag = i - 1
11+
for j in 1:j_pre_diag
12+
push!(to_check, Dagger.@spawn is_cross_symmetric(Ac[i, j], Ac[j, i]))
13+
end
14+
end
15+
16+
return all(fetch, to_check)
17+
end
18+
function LinearAlgebra.ishermitian(A::DArray{T,2}) where T
19+
Ac = A.chunks
20+
if size(Ac, 1) != size(Ac, 2)
21+
return false
22+
end
23+
24+
to_check = [Dagger.@spawn ishermitian(Ac[i, i]) for i in 1:size(Ac, 1)]
25+
for i in 2:(size(Ac, 1)-1)
26+
j_pre_diag = i - 1
27+
for j in 1:j_pre_diag
28+
push!(to_check, Dagger.@spawn is_cross_symmetric(Ac[i, j], Ac[j, i]))
29+
end
30+
end
31+
32+
return all(fetch, to_check)
33+
end
34+
35+
LinearAlgebra.cholcopy(A::DArray{T,2}) where T = copy(A)
36+
function LinearAlgebra.cholesky!(A::DArray{T,2}, ::NoPivot; check::Bool=true) where T
37+
if check
38+
ishermitian(A) || throw(PosDefException(-1))
39+
end
40+
41+
zone = one(T)
42+
mzone = -one(T)
43+
uplo = 'U'
44+
Ac = A.chunks
45+
mt, nt = size(Ac)
46+
47+
if uplo == 'L'
48+
Dagger.spawn_datadeps() do
49+
for k in range(1, mt)
50+
Dagger.@spawn LAPACK.potrf!(uplo, InOut(Ac[k, k]))
51+
for m in range(k+1, mt)
52+
Dagger.@spawn BLAS.trsm!('R', 'L', 'T', 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]))
53+
end
54+
for n in range(k+1, nt)
55+
Dagger.@spawn BLAS.syrk!(uplo, 'N', mzone, In(Ac[n, k]), zone, InOut(Ac[n, n]))
56+
for m in range(n+1, mt)
57+
Dagger.@spawn BLAS.gemm!('N', 'T', mzone, In(Ac[m, k]), In(Ac[n, k]), zone, InOut(Ac[m, n]))
58+
end
59+
end
60+
end
61+
end
62+
63+
return Cholesky(A, 'L', 0)
64+
elseif uplo == 'U'
65+
Dagger.spawn_datadeps() do
66+
for k in range(1, mt)
67+
Dagger.@spawn LAPACK.potrf!(uplo, InOut(Ac[k, k]))
68+
for n in range(k+1, nt)
69+
Dagger.@spawn BLAS.trsm!('L', uplo, 'T', 'N', zone, In(Ac[k, k]), InOut(Ac[k, n]))
70+
end
71+
for m in range(k+1, mt)
72+
Dagger.@spawn BLAS.syrk!(uplo, 'T', mzone, In(Ac[k, m]), zone, InOut(Ac[m, m]))
73+
for n in range(m+1, nt)
74+
Dagger.@spawn BLAS.gemm!('T', 'N', mzone, In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n]))
75+
end
76+
end
77+
end
78+
end
79+
80+
return Cholesky(A, 'U', 0)
81+
end
82+
end

Diff for: test/linalg.jl

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
@testset "Linear Algebra" begin
2+
@testset "Cholesky: $T" for T in (Float32, Float64, ComplexF64)
3+
D = rand(Blocks(4, 4), T, 32, 32)
4+
if !(T <: Complex)
5+
@test !issymmetric(D)
6+
end
7+
@test !ishermitian(D)
8+
9+
A = rand(T, 128, 128)
10+
A = A * A'
11+
DA = view(A, Blocks(32, 32))
12+
if !(T <: Complex)
13+
@test issymmetric(DA)
14+
end
15+
@test ishermitian(DA)
16+
17+
# Out-of-place
18+
chol_A = cholesky(A)
19+
chol_DA = cholesky(DA)
20+
@test chol_DA isa Cholesky
21+
@test chol_A.L chol_DA.L
22+
@test chol_A.U chol_DA.U
23+
24+
# In-place
25+
A_copy = copy(A)
26+
chol_A = cholesky!(A_copy)
27+
chol_DA = cholesky!(DA)
28+
@test chol_DA isa Cholesky
29+
@test chol_A.L chol_DA.L
30+
@test chol_A.U chol_DA.U
31+
# Check that changes propagated to A
32+
@test UpperTriangular(collect(DA)) UpperTriangular(collect(A))
33+
end
34+
end

Diff for: test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ include("task-queues.jl")
3838
include("datadeps.jl")
3939
include("domain.jl")
4040
include("array.jl")
41+
include("linalg.jl")
4142
include("cache.jl")
4243
include("diskcaching.jl")
4344
include("file-io.jl")

0 commit comments

Comments
 (0)