|
| 1 | +is_cross_symmetric(A1, A2) = A1 == A2' |
| 2 | +function LinearAlgebra.issymmetric(A::DArray{T,2}) where T |
| 3 | + Ac = A.chunks |
| 4 | + if size(Ac, 1) != size(Ac, 2) |
| 5 | + return false |
| 6 | + end |
| 7 | + |
| 8 | + to_check = [Dagger.@spawn issymmetric(Ac[i, i]) for i in 1:size(Ac, 1)] |
| 9 | + for i in 2:(size(Ac, 1)-1) |
| 10 | + j_pre_diag = i - 1 |
| 11 | + for j in 1:j_pre_diag |
| 12 | + push!(to_check, Dagger.@spawn is_cross_symmetric(Ac[i, j], Ac[j, i])) |
| 13 | + end |
| 14 | + end |
| 15 | + |
| 16 | + return all(fetch, to_check) |
| 17 | +end |
| 18 | +function LinearAlgebra.ishermitian(A::DArray{T,2}) where T |
| 19 | + Ac = A.chunks |
| 20 | + if size(Ac, 1) != size(Ac, 2) |
| 21 | + return false |
| 22 | + end |
| 23 | + |
| 24 | + to_check = [Dagger.@spawn ishermitian(Ac[i, i]) for i in 1:size(Ac, 1)] |
| 25 | + for i in 2:(size(Ac, 1)-1) |
| 26 | + j_pre_diag = i - 1 |
| 27 | + for j in 1:j_pre_diag |
| 28 | + push!(to_check, Dagger.@spawn is_cross_symmetric(Ac[i, j], Ac[j, i])) |
| 29 | + end |
| 30 | + end |
| 31 | + |
| 32 | + return all(fetch, to_check) |
| 33 | +end |
| 34 | + |
| 35 | +LinearAlgebra.cholcopy(A::DArray{T,2}) where T = copy(A) |
| 36 | +function LinearAlgebra.cholesky!(A::DArray{T,2}, ::NoPivot; check::Bool=true) where T |
| 37 | + if check |
| 38 | + ishermitian(A) || throw(PosDefException(-1)) |
| 39 | + end |
| 40 | + |
| 41 | + zone = one(T) |
| 42 | + mzone = -one(T) |
| 43 | + uplo = 'U' |
| 44 | + Ac = A.chunks |
| 45 | + mt, nt = size(Ac) |
| 46 | + |
| 47 | + if uplo == 'L' |
| 48 | + Dagger.spawn_datadeps() do |
| 49 | + for k in range(1, mt) |
| 50 | + Dagger.@spawn LAPACK.potrf!(uplo, InOut(Ac[k, k])) |
| 51 | + for m in range(k+1, mt) |
| 52 | + Dagger.@spawn BLAS.trsm!('R', 'L', 'T', 'N', zone, In(Ac[k, k]), InOut(Ac[m, k])) |
| 53 | + end |
| 54 | + for n in range(k+1, nt) |
| 55 | + Dagger.@spawn BLAS.syrk!(uplo, 'N', mzone, In(Ac[n, k]), zone, InOut(Ac[n, n])) |
| 56 | + for m in range(n+1, mt) |
| 57 | + Dagger.@spawn BLAS.gemm!('N', 'T', mzone, In(Ac[m, k]), In(Ac[n, k]), zone, InOut(Ac[m, n])) |
| 58 | + end |
| 59 | + end |
| 60 | + end |
| 61 | + end |
| 62 | + |
| 63 | + return Cholesky(A, 'L', 0) |
| 64 | + elseif uplo == 'U' |
| 65 | + Dagger.spawn_datadeps() do |
| 66 | + for k in range(1, mt) |
| 67 | + Dagger.@spawn LAPACK.potrf!(uplo, InOut(Ac[k, k])) |
| 68 | + for n in range(k+1, nt) |
| 69 | + Dagger.@spawn BLAS.trsm!('L', uplo, 'T', 'N', zone, In(Ac[k, k]), InOut(Ac[k, n])) |
| 70 | + end |
| 71 | + for m in range(k+1, mt) |
| 72 | + Dagger.@spawn BLAS.syrk!(uplo, 'T', mzone, In(Ac[k, m]), zone, InOut(Ac[m, m])) |
| 73 | + for n in range(m+1, nt) |
| 74 | + Dagger.@spawn BLAS.gemm!('T', 'N', mzone, In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n])) |
| 75 | + end |
| 76 | + end |
| 77 | + end |
| 78 | + end |
| 79 | + |
| 80 | + return Cholesky(A, 'U', 0) |
| 81 | + end |
| 82 | +end |
0 commit comments