Skip to content

Commit 6ec86a1

Browse files
authored
Merge pull request #565 from JuliaParallel/rabab/lu
LU implementation for DArray
2 parents e7eae88 + 38c4a9b commit 6ec86a1

File tree

9 files changed

+81
-21
lines changed

9 files changed

+81
-21
lines changed

Diff for: .buildkite/pipeline.yml

+8-18
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,6 @@
1515
arch: x86_64
1616
num_cpus: 16
1717
steps:
18-
- label: Julia 1.8
19-
timeout_in_minutes: 90
20-
<<: *test
21-
plugins:
22-
- JuliaCI/julia#v1:
23-
version: "1.8"
24-
- JuliaCI/julia-test#v1:
25-
julia_args: "--threads=1"
26-
- JuliaCI/julia-coverage#v1:
27-
codecov: true
2818
- label: Julia 1.9
2919
timeout_in_minutes: 90
3020
<<: *test
@@ -55,7 +45,7 @@ steps:
5545
julia_args: "--threads=1"
5646
- JuliaCI/julia-coverage#v1:
5747
codecov: true
58-
- label: Julia 1.8 (macOS)
48+
- label: Julia 1.9 (macOS)
5949
timeout_in_minutes: 90
6050
<<: *test
6151
agents:
@@ -64,26 +54,26 @@ steps:
6454
arch: x86_64
6555
plugins:
6656
- JuliaCI/julia#v1:
67-
version: "1.8"
57+
version: "1.9"
6858
- JuliaCI/julia-test#v1:
6959
julia_args: "--threads=1"
7060
- JuliaCI/julia-coverage#v1:
7161
codecov: true
72-
- label: Julia 1.8 - TimespanLogging
62+
- label: Julia 1.9 - TimespanLogging
7363
timeout_in_minutes: 20
7464
<<: *test
7565
plugins:
7666
- JuliaCI/julia#v1:
77-
version: "1.8"
67+
version: "1.9"
7868
- JuliaCI/julia-coverage#v1:
7969
codecov: true
8070
command: "julia --project -e 'using Pkg; Pkg.instantiate(); Pkg.develop(;path=\"lib/TimespanLogging\"); Pkg.test(\"TimespanLogging\")'"
81-
- label: Julia 1.8 - DaggerWebDash
71+
- label: Julia 1.9 - DaggerWebDash
8272
timeout_in_minutes: 20
8373
<<: *test
8474
plugins:
8575
- JuliaCI/julia#v1:
86-
version: "1.8"
76+
version: "1.9"
8777
- JuliaCI/julia-coverage#v1:
8878
codecov: true
8979
command: "julia -e 'using Pkg; Pkg.develop(;path=pwd()); Pkg.develop(;path=\"lib/TimespanLogging\"); Pkg.develop(;path=\"lib/DaggerWebDash\"); include(\"lib/DaggerWebDash/test/runtests.jl\")'"
@@ -92,7 +82,7 @@ steps:
9282
<<: *bench
9383
plugins:
9484
- JuliaCI/julia#v1:
95-
version: "1.8"
85+
version: "1.9"
9686
- JuliaCI/julia-test#v1:
9787
run_tests: false
9888
command: "julia -e 'using Pkg; Pkg.add(\"BenchmarkTools\"); Pkg.develop(;path=pwd())'; JULIA_PROJECT=\"$PWD\" julia --project benchmarks/benchmark.jl"
@@ -107,7 +97,7 @@ steps:
10797
timeout_in_minutes: 20
10898
plugins:
10999
- JuliaCI/julia#v1:
110-
version: "1.8"
100+
version: "1.9"
111101
env:
112102
JULIA_NUM_THREADS: "4"
113103
agents:

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Statistics = "1"
5959
StatsBase = "0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
6060
TaskLocalValues = "0.1"
6161
TimespanLogging = "0.1"
62-
julia = "1.8"
62+
julia = "1.9"
6363

6464
[extras]
6565
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"

Diff for: docs/src/darray.md

+1
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,4 @@ From `LinearAlgebra`:
446446
- `*` (Out-of-place Matrix-(Matrix/Vector) multiply)
447447
- `mul!` (In-place Matrix-Matrix multiply)
448448
- `cholesky`/`cholesky!` (In-place/Out-of-place Cholesky factorization)
449+
- `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` only))

Diff for: src/Dagger.jl

+1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ include("array/sort.jl")
8484
include("array/linalg.jl")
8585
include("array/mul.jl")
8686
include("array/cholesky.jl")
87+
include("array/lu.jl")
8788
include("array/random.jl")
8889

8990
# Visualization

Diff for: src/array/linalg.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function LinearAlgebra.norm2(A::DArray{T,2}) where T
22
Ac = A.chunks
3-
norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac]
4-
return sqrt(sum(map(fetch, norms)))
3+
norms = [Dagger.@spawn mapreduce(LinearAlgebra.norm_sqr, +, chunk) for chunk in Ac]::Matrix{DTask}
4+
return sqrt(sum(map(norm->fetch(norm)::real(T), norms)))
55
end
66
function LinearAlgebra.norm2(A::UpperTriangular{T,<:DArray{T,2}}) where T
77
Ac = parent(A).chunks

Diff for: src/array/lu.jl

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
function LinearAlgebra.lu(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=true) where T
2+
A_copy = LinearAlgebra._lucopy(A, LinearAlgebra.lutype(T))
3+
return LinearAlgebra.lu!(A_copy, LinearAlgebra.NoPivot(); check=check)
4+
end
5+
function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=true) where T
6+
zone = one(T)
7+
mzone = -one(T)
8+
Ac = A.chunks
9+
mt, nt = size(Ac)
10+
iscomplex = T <: Complex
11+
trans = iscomplex ? 'C' : 'T'
12+
13+
Dagger.spawn_datadeps() do
14+
for k in range(1, min(mt, nt))
15+
Dagger.@spawn LinearAlgebra.generic_lufact!(InOut(Ac[k, k]), LinearAlgebra.NoPivot(); check)
16+
for m in range(k+1, mt)
17+
Dagger.@spawn BLAS.trsm!('R', 'U', 'N', 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]))
18+
end
19+
for n in range(k+1, nt)
20+
Dagger.@spawn BLAS.trsm!('L', 'L', 'N', 'U', zone, In(Ac[k, k]), InOut(Ac[k, n]))
21+
for m in range(k+1, mt)
22+
Dagger.@spawn BLAS.gemm!('N', 'N', mzone, In(Ac[m, k]), In(Ac[k, n]), zone, InOut(Ac[m, n]))
23+
end
24+
end
25+
end
26+
end
27+
28+
ipiv = DVector([i for i in 1:min(size(A)...)])
29+
30+
return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, 0)
31+
end

Diff for: test/array/linalg/cholesky.jl

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
A = rand(T, 128, 128)
99
A = A * A'
1010
A[diagind(A)] .+= size(A, 1)
11+
B = copy(A)
1112
DA = view(A, Blocks(32, 32))
1213
if !(T <: Complex)
1314
@test issymmetric(DA)
@@ -20,6 +21,8 @@
2021
@test chol_DA isa Cholesky
2122
@test chol_A.L chol_DA.L
2223
@test chol_A.U chol_DA.U
24+
# Check that cholesky did not modify A or DA
25+
@test A DA B
2326

2427
# In-place
2528
A_copy = copy(A)

Diff for: test/array/linalg/lu.jl

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
@testset "$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
2+
A = rand(T, 128, 128)
3+
B = copy(A)
4+
DA = view(A, Blocks(64, 64))
5+
6+
# Out-of-place
7+
lu_A = lu(A, NoPivot())
8+
lu_DA = lu(DA, NoPivot())
9+
@test lu_DA isa LU{T,DMatrix{T},DVector{Int}}
10+
if !(T in (Float32, ComplexF32)) # FIXME: NoPivot is unstable for FP32
11+
@test lu_A.L lu_DA.L
12+
@test lu_A.U lu_DA.U
13+
end
14+
@test lu_A.P lu_DA.P
15+
@test lu_A.p lu_DA.p
16+
# Check that lu did not modify A or DA
17+
@test A DA B
18+
19+
# In-place
20+
A_copy = copy(A)
21+
lu_A = lu!(A_copy, NoPivot())
22+
lu_DA = lu!(DA, NoPivot())
23+
@test lu_DA isa LU{T,DMatrix{T},DVector{Int}}
24+
if !(T in (Float32, ComplexF32)) # FIXME: NoPivot is unstable for FP32
25+
@test lu_A.L lu_DA.L
26+
@test lu_A.U lu_DA.U
27+
end
28+
@test lu_A.P lu_DA.P
29+
@test lu_A.p lu_DA.p
30+
# Check that changes propagated to A
31+
@test DA A
32+
@test !(B A)
33+
end

Diff for: test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ tests = [
1818
("Array - MapReduce", "array/mapreduce.jl"),
1919
("Array - LinearAlgebra - Matmul", "array/linalg/matmul.jl"),
2020
("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"),
21+
("Array - LinearAlgebra - LU", "array/linalg/lu.jl"),
2122
("Array - Random", "array/random.jl"),
2223
("Caching", "cache.jl"),
2324
("Disk Caching", "diskcaching.jl"),

0 commit comments

Comments
 (0)