Skip to content

Commit 1816efd

Browse files
Rabab53jpsamaroo
authored andcommitted
DArray: Implement LU factorization
1 parent e7eae88 commit 1816efd

File tree

5 files changed

+67
-0
lines changed

5 files changed

+67
-0
lines changed

docs/src/darray.md

+1
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,4 @@ From `LinearAlgebra`:
446446
- `*` (Out-of-place Matrix-(Matrix/Vector) multiply)
447447
- `mul!` (In-place Matrix-Matrix multiply)
448448
- `cholesky`/`cholesky!` (In-place/Out-of-place Cholesky factorization)
449+
- `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` only))

src/Dagger.jl

+1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ include("array/sort.jl")
8484
include("array/linalg.jl")
8585
include("array/mul.jl")
8686
include("array/cholesky.jl")
87+
include("array/lu.jl")
8788
include("array/random.jl")
8889

8990
# Visualization

src/array/lu.jl

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
function LinearAlgebra.lu(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=true) where T
2+
A_copy = LinearAlgebra._lucopy(A, LinearAlgebra.lutype(T))
3+
return LinearAlgebra.lu!(A_copy, LinearAlgebra.NoPivot(); check=check)
4+
end
5+
function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=true) where T
6+
zone = one(T)
7+
mzone = -one(T)
8+
Ac = A.chunks
9+
mt, nt = size(Ac)
10+
iscomplex = T <: Complex
11+
trans = iscomplex ? 'C' : 'T'
12+
13+
Dagger.spawn_datadeps() do
14+
for k in range(1, min(mt, nt))
15+
Dagger.@spawn LinearAlgebra.generic_lufact!(InOut(Ac[k, k]), LinearAlgebra.NoPivot(); check)
16+
for m in range(k+1, mt)
17+
Dagger.@spawn BLAS.trsm!('R', 'U', 'N', 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]))
18+
end
19+
for n in range(k+1, nt)
20+
Dagger.@spawn BLAS.trsm!('L', 'L', 'N', 'U', zone, In(Ac[k, k]), InOut(Ac[k, n]))
21+
for m in range(k+1, mt)
22+
Dagger.@spawn BLAS.gemm!('N', 'N', mzone, In(Ac[m, k]), In(Ac[k, n]), zone, InOut(Ac[m, n]))
23+
end
24+
end
25+
end
26+
end
27+
28+
ipiv = DVector([i for i in 1:min(size(A)...)])
29+
30+
return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, 0)
31+
end

test/array/linalg/lu.jl

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
@testset "$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
2+
A = rand(T, 128, 128)
3+
B = copy(A)
4+
DA = view(A, Blocks(64, 64))
5+
6+
# Out-of-place
7+
lu_A = lu(A, NoPivot())
8+
lu_DA = lu(DA, NoPivot())
9+
@test lu_DA isa LU{T,DMatrix{T},DVector{Int}}
10+
if !(T in (Float32, ComplexF32)) # FIXME: NoPivot is unstable for FP32
11+
@test lu_A.L lu_DA.L
12+
@test lu_A.U lu_DA.U
13+
end
14+
@test lu_A.P lu_DA.P
15+
@test lu_A.p lu_DA.p
16+
# Check that lu did not modify A or DA
17+
@test A DA B
18+
19+
# In-place
20+
A_copy = copy(A)
21+
lu_A = lu!(A_copy, NoPivot())
22+
lu_DA = lu!(DA, NoPivot())
23+
@test lu_DA isa LU{T,DMatrix{T},DVector{Int}}
24+
if !(T in (Float32, ComplexF32)) # FIXME: NoPivot is unstable for FP32
25+
@test lu_A.L lu_DA.L
26+
@test lu_A.U lu_DA.U
27+
end
28+
@test lu_A.P lu_DA.P
29+
@test lu_A.p lu_DA.p
30+
# Check that changes propagated to A
31+
@test DA A
32+
@test !(B A)
33+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ tests = [
1818
("Array - MapReduce", "array/mapreduce.jl"),
1919
("Array - LinearAlgebra - Matmul", "array/linalg/matmul.jl"),
2020
("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"),
21+
("Array - LinearAlgebra - LU", "array/linalg/lu.jl"),
2122
("Array - Random", "array/random.jl"),
2223
("Caching", "cache.jl"),
2324
("Disk Caching", "diskcaching.jl"),

0 commit comments

Comments
 (0)