Skip to content

Commit d77ab1b

Browse files
committed
fixup! fixup! fixup! DArray: Add cholesky implementation
1 parent 7f915b9 commit d77ab1b

File tree

2 files changed

+68
-25
lines changed

2 files changed

+68
-25
lines changed

Diff for: src/array/cholesky.jl

+55-25
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11
LinearAlgebra.cholcopy(A::DArray{T,2}) where T = copy(A)
2+
function potrf_checked!(uplo, A, info_arr)
3+
_A, info = LAPACK.potrf!(uplo, A)
4+
if info > 0
5+
info_arr[1] = info
6+
throw(PosDefException(info))
7+
end
8+
return _A, info
9+
end
210
function LinearAlgebra._chol!(A::DArray{T,2}, ::Type{UpperTriangular}) where T
311
LinearAlgebra.checksquare(A)
412

@@ -12,50 +20,72 @@ function LinearAlgebra._chol!(A::DArray{T,2}, ::Type{UpperTriangular}) where T
1220
iscomplex = T <: Complex
1321
trans = iscomplex ? 'C' : 'T'
1422

15-
Dagger.spawn_datadeps() do
16-
for k in range(1, mt)
17-
Dagger.@spawn LAPACK.potrf!(uplo, InOut(Ac[k, k]))
18-
for n in range(k+1, nt)
19-
Dagger.@spawn BLAS.trsm!('L', uplo, trans, 'N', zone, In(Ac[k, k]), InOut(Ac[k, n]))
20-
end
21-
for m in range(k+1, mt)
22-
if iscomplex
23-
Dagger.@spawn BLAS.herk!(uplo, 'C', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
24-
else
25-
Dagger.@spawn BLAS.syrk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
23+
info = [convert(LinearAlgebra.BlasInt, 0)]
24+
try
25+
Dagger.spawn_datadeps() do
26+
for k in range(1, mt)
27+
Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info))
28+
for n in range(k+1, nt)
29+
Dagger.@spawn BLAS.trsm!('L', uplo, trans, 'N', zone, In(Ac[k, k]), InOut(Ac[k, n]))
2630
end
27-
for n in range(m+1, nt)
28-
Dagger.@spawn BLAS.gemm!(trans, 'N', mzone, In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n]))
31+
for m in range(k+1, mt)
32+
if iscomplex
33+
Dagger.@spawn BLAS.herk!(uplo, 'C', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
34+
else
35+
Dagger.@spawn BLAS.syrk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
36+
end
37+
for n in range(m+1, nt)
38+
Dagger.@spawn BLAS.gemm!(trans, 'N', mzone, In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n]))
39+
end
2940
end
3041
end
3142
end
43+
catch err
44+
err isa ThunkFailedException || rethrow()
45+
err = Dagger.Sch.unwrap_nested_exception(err.ex)
46+
err isa PosDefException || rethrow()
3247
end
3348

34-
return UpperTriangular(A), convert(LinearAlgebra.BlasInt, 0)
49+
return UpperTriangular(A), info[1]
3550
end
3651
function LinearAlgebra._chol!(A::DArray{T,2}, ::Type{LowerTriangular}) where T
3752
LinearAlgebra.checksquare(A)
3853

3954
zone = one(T)
4055
mzone = -one(T)
56+
rzone = one(real(T))
57+
rmzone = -one(real(T))
4158
uplo = 'L'
4259
Ac = A.chunks
4360
mt, nt = size(Ac)
61+
iscomplex = T <: Complex
62+
trans = iscomplex ? 'C' : 'T'
4463

45-
Dagger.spawn_datadeps() do
46-
for k in range(1, mt)
47-
Dagger.@spawn LAPACK.potrf!(uplo, InOut(Ac[k, k]))
48-
for m in range(k+1, mt)
49-
Dagger.@spawn BLAS.trsm!('R', uplo, 'T', 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]))
50-
end
51-
for n in range(k+1, nt)
52-
Dagger.@spawn BLAS.syrk!(uplo, 'N', mzone, In(Ac[n, k]), zone, InOut(Ac[n, n]))
53-
for m in range(n+1, mt)
54-
Dagger.@spawn BLAS.gemm!('N', 'T', mzone, In(Ac[m, k]), In(Ac[n, k]), zone, InOut(Ac[m, n]))
64+
info = [convert(LinearAlgebra.BlasInt, 0)]
65+
try
66+
Dagger.spawn_datadeps() do
67+
for k in range(1, mt)
68+
Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info))
69+
for m in range(k+1, mt)
70+
Dagger.@spawn BLAS.trsm!('R', uplo, trans, 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]))
71+
end
72+
for n in range(k+1, nt)
73+
if iscomplex
74+
Dagger.@spawn BLAS.herk!(uplo, 'N', rmzone, In(Ac[n, k]), rzone, InOut(Ac[n, n]))
75+
else
76+
Dagger.@spawn BLAS.syrk!(uplo, 'N', rmzone, In(Ac[n, k]), rzone, InOut(Ac[n, n]))
77+
end
78+
for m in range(n+1, mt)
79+
Dagger.@spawn BLAS.gemm!('N', trans, mzone, In(Ac[m, k]), In(Ac[n, k]), zone, InOut(Ac[m, n]))
80+
end
5581
end
5682
end
5783
end
84+
catch err
85+
err isa ThunkFailedException || rethrow()
86+
err = Dagger.Sch.unwrap_nested_exception(err.ex)
87+
err isa PosDefException || rethrow()
5888
end
5989

60-
return LowerTriangular(A), convert(LinearAlgebra.BlasInt, 0)
90+
return LowerTriangular(A), info[1]
6191
end

Diff for: test/linalg.jl

+13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
A = rand(T, 128, 128)
1010
A = A * A'
11+
A[diagind(A)] .+= size(A, 1)
1112
DA = view(A, Blocks(32, 32))
1213
if !(T <: Complex)
1314
@test issymmetric(DA)
@@ -30,5 +31,17 @@
3031
@test chol_A.U chol_DA.U
3132
# Check that changes propagated to A
3233
@test UpperTriangular(collect(DA)) UpperTriangular(collect(A))
34+
35+
# Non-PosDef matrix
36+
A = rand(T, 128, 128)
37+
A = A * A'
38+
A[diagind(A)] .+= size(A, 1)
39+
A[1, 1] = -100
40+
DA = view(A, Blocks(32, 32))
41+
if !(T <: Complex)
42+
@test issymmetric(DA)
43+
end
44+
@test ishermitian(DA)
45+
@test_throws_unwrap PosDefException cholesky(DA).U
3346
end
3447
end

0 commit comments

Comments
 (0)