1
1
LinearAlgebra. cholcopy (A:: DArray{T,2} ) where T = copy (A)
2
+ function potrf_checked! (uplo, A, info_arr)
3
+ _A, info = LAPACK. potrf! (uplo, A)
4
+ if info > 0
5
+ info_arr[1 ] = info
6
+ throw (PosDefException (info))
7
+ end
8
+ return _A, info
9
+ end
2
10
function LinearAlgebra. _chol! (A:: DArray{T,2} , :: Type{UpperTriangular} ) where T
3
11
LinearAlgebra. checksquare (A)
4
12
@@ -12,50 +20,72 @@ function LinearAlgebra._chol!(A::DArray{T,2}, ::Type{UpperTriangular}) where T
12
20
iscomplex = T <: Complex
13
21
trans = iscomplex ? ' C' : ' T'
14
22
15
- Dagger. spawn_datadeps () do
16
- for k in range (1 , mt)
17
- Dagger. @spawn LAPACK. potrf! (uplo, InOut (Ac[k, k]))
18
- for n in range (k+ 1 , nt)
19
- Dagger. @spawn BLAS. trsm! (' L' , uplo, trans, ' N' , zone, In (Ac[k, k]), InOut (Ac[k, n]))
20
- end
21
- for m in range (k+ 1 , mt)
22
- if iscomplex
23
- Dagger. @spawn BLAS. herk! (uplo, ' C' , rmzone, In (Ac[k, m]), rzone, InOut (Ac[m, m]))
24
- else
25
- Dagger. @spawn BLAS. syrk! (uplo, ' T' , rmzone, In (Ac[k, m]), rzone, InOut (Ac[m, m]))
23
+ info = [convert (LinearAlgebra. BlasInt, 0 )]
24
+ try
25
+ Dagger. spawn_datadeps () do
26
+ for k in range (1 , mt)
27
+ Dagger. @spawn potrf_checked! (uplo, InOut (Ac[k, k]), Out (info))
28
+ for n in range (k+ 1 , nt)
29
+ Dagger. @spawn BLAS. trsm! (' L' , uplo, trans, ' N' , zone, In (Ac[k, k]), InOut (Ac[k, n]))
26
30
end
27
- for n in range (m+ 1 , nt)
28
- Dagger. @spawn BLAS. gemm! (trans, ' N' , mzone, In (Ac[k, m]), In (Ac[k, n]), zone, InOut (Ac[m, n]))
31
+ for m in range (k+ 1 , mt)
32
+ if iscomplex
33
+ Dagger. @spawn BLAS. herk! (uplo, ' C' , rmzone, In (Ac[k, m]), rzone, InOut (Ac[m, m]))
34
+ else
35
+ Dagger. @spawn BLAS. syrk! (uplo, ' T' , rmzone, In (Ac[k, m]), rzone, InOut (Ac[m, m]))
36
+ end
37
+ for n in range (m+ 1 , nt)
38
+ Dagger. @spawn BLAS. gemm! (trans, ' N' , mzone, In (Ac[k, m]), In (Ac[k, n]), zone, InOut (Ac[m, n]))
39
+ end
29
40
end
30
41
end
31
42
end
43
+ catch err
44
+ err isa ThunkFailedException || rethrow ()
45
+ err = Dagger. Sch. unwrap_nested_exception (err. ex)
46
+ err isa PosDefException || rethrow ()
32
47
end
33
48
34
- return UpperTriangular (A), convert (LinearAlgebra . BlasInt, 0 )
49
+ return UpperTriangular (A), info[ 1 ]
35
50
end
36
51
function LinearAlgebra. _chol! (A:: DArray{T,2} , :: Type{LowerTriangular} ) where T
37
52
LinearAlgebra. checksquare (A)
38
53
39
54
zone = one (T)
40
55
mzone = - one (T)
56
+ rzone = one (real (T))
57
+ rmzone = - one (real (T))
41
58
uplo = ' L'
42
59
Ac = A. chunks
43
60
mt, nt = size (Ac)
61
+ iscomplex = T <: Complex
62
+ trans = iscomplex ? ' C' : ' T'
44
63
45
- Dagger. spawn_datadeps () do
46
- for k in range (1 , mt)
47
- Dagger. @spawn LAPACK. potrf! (uplo, InOut (Ac[k, k]))
48
- for m in range (k+ 1 , mt)
49
- Dagger. @spawn BLAS. trsm! (' R' , uplo, ' T' , ' N' , zone, In (Ac[k, k]), InOut (Ac[m, k]))
50
- end
51
- for n in range (k+ 1 , nt)
52
- Dagger. @spawn BLAS. syrk! (uplo, ' N' , mzone, In (Ac[n, k]), zone, InOut (Ac[n, n]))
53
- for m in range (n+ 1 , mt)
54
- Dagger. @spawn BLAS. gemm! (' N' , ' T' , mzone, In (Ac[m, k]), In (Ac[n, k]), zone, InOut (Ac[m, n]))
64
+ info = [convert (LinearAlgebra. BlasInt, 0 )]
65
+ try
66
+ Dagger. spawn_datadeps () do
67
+ for k in range (1 , mt)
68
+ Dagger. @spawn potrf_checked! (uplo, InOut (Ac[k, k]), Out (info))
69
+ for m in range (k+ 1 , mt)
70
+ Dagger. @spawn BLAS. trsm! (' R' , uplo, trans, ' N' , zone, In (Ac[k, k]), InOut (Ac[m, k]))
71
+ end
72
+ for n in range (k+ 1 , nt)
73
+ if iscomplex
74
+ Dagger. @spawn BLAS. herk! (uplo, ' N' , rmzone, In (Ac[n, k]), rzone, InOut (Ac[n, n]))
75
+ else
76
+ Dagger. @spawn BLAS. syrk! (uplo, ' N' , rmzone, In (Ac[n, k]), rzone, InOut (Ac[n, n]))
77
+ end
78
+ for m in range (n+ 1 , mt)
79
+ Dagger. @spawn BLAS. gemm! (' N' , trans, mzone, In (Ac[m, k]), In (Ac[n, k]), zone, InOut (Ac[m, n]))
80
+ end
55
81
end
56
82
end
57
83
end
84
+ catch err
85
+ err isa ThunkFailedException || rethrow ()
86
+ err = Dagger. Sch. unwrap_nested_exception (err. ex)
87
+ err isa PosDefException || rethrow ()
58
88
end
59
89
60
- return LowerTriangular (A), convert (LinearAlgebra . BlasInt, 0 )
90
+ return LowerTriangular (A), info[ 1 ]
61
91
end
0 commit comments