Skip to content

Commit 777b81d

Browse files
committed
Add getindex method to LU{Triangular} for extracting factors
1 parent e3dfa56 commit 777b81d

File tree

2 files changed

+34
-12
lines changed

2 files changed

+34
-12
lines changed

base/linalg/lu.jl

+30-11
Original file line numberDiff line numberDiff line change
@@ -113,23 +113,18 @@ function ipiv2perm{T}(v::AbstractVector{T}, maxi::Integer)
113113
return p
114114
end
115115

116-
function getindex{T,S<:StridedMatrix}(A::LU{T,S}, d::Symbol)
117-
m, n = size(A)
116+
function getindex{T,S<:StridedMatrix}(F::LU{T,S}, d::Symbol)
117+
m, n = size(F)
118118
if d == :L
119-
L = tril!(A.factors[1:m, 1:min(m,n)])
119+
L = tril!(F.factors[1:m, 1:min(m,n)])
120120
for i = 1:min(m,n); L[i,i] = one(T); end
121121
return L
122122
elseif d == :U
123-
return triu!(A.factors[1:min(m,n), 1:n])
123+
return triu!(F.factors[1:min(m,n), 1:n])
124124
elseif d == :p
125-
return ipiv2perm(A.ipiv, m)
125+
return ipiv2perm(F.ipiv, m)
126126
elseif d == :P
127-
p = A[:p]
128-
P = zeros(T, m, m)
129-
for i in 1:m
130-
P[i,p[i]] = one(T)
131-
end
132-
return P
127+
return eye(T, m)[:,invperm(F[:p])]
133128
else
134129
throw(KeyError(d))
135130
end
@@ -263,6 +258,30 @@ end
263258

264259
factorize(A::Tridiagonal) = lufact(A)
265260

261+
function getindex{T}(F::Base.LinAlg.LU{T,Tridiagonal{T}}, d::Symbol)
262+
m, n = size(F)
263+
if d == :L
264+
L = full(Bidiagonal(ones(T, n), F.factors.dl, false))
265+
for i = 2:n
266+
tmp = L[F.ipiv[i], 1:i - 1]
267+
L[F.ipiv[i], 1:i - 1] = L[i, 1:i - 1]
268+
L[i, 1:i - 1] = tmp
269+
end
270+
return L
271+
elseif d == :U
272+
U = full(Bidiagonal(F.factors.d, F.factors.du, true))
273+
for i = 1:n - 2
274+
U[i,i + 2] = F.factors.du2[i]
275+
end
276+
return U
277+
elseif d == :p
278+
return ipiv2perm(F.ipiv, m)
279+
elseif d == :P
280+
return eye(T, m)[:,invperm(F[:p])]
281+
end
282+
throw(KeyError(d))
283+
end
284+
266285
# See dgtts2.f
267286
function A_ldiv_B!{T}(A::LU{T,Tridiagonal{T}}, B::AbstractVecOrMat)
268287
n = size(A,1)

test/linalg/lu.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
debug = false
44

5+
using Base.Test
6+
57
n = 10
68

79
# Split n into 2 parts for tests needing two matrices
@@ -57,6 +59,8 @@ debug && println("(Automatic) Square LU decomposition")
5759
debug && println("Tridiagonal LU")
5860
κd = cond(full(d),1)
5961
lud = lufact(d)
62+
@test_approx_eq lud[:L]*lud[:U] lud[:P]*full(d)
63+
@test_approx_eq lud[:L]*lud[:U] full(d)[lud[:p],:]
6064
@test norm(d*(lud\b) - b, 1) < ε*κd*n*2 # Two because the right hand side has two columns
6165
if eltya <: Real
6266
@test norm((lud.'\b) - full(d.')\b, 1) < ε*κd*n*2 # Two because the right hand side has two columns
@@ -112,6 +116,5 @@ for elty in (Float32, Float64, Complex64, Complex128)
112116
# @test norm(F[:vectors]*Diagonal(F[:values])/F[:vectors] - A) > 0.01
113117
end
114118

115-
116119
@test @inferred(logdet(Complex64[1.0f0 0.5f0; 0.5f0 -1.0f0])) === 0.22314355f0 + 3.1415927f0im
117120
@test_throws DomainError logdet([1 1; 1 -1])

0 commit comments

Comments
 (0)