Skip to content

Commit 571eb0d

Browse files
committed
Merge pull request #7236 from JuliaLang/anj/chol
Add generic Cholesky decomposition and make Cholesky parametric on matrix type
2 parents 9610786 + 5e88074 commit 571eb0d

File tree

7 files changed

+231
-157
lines changed

7 files changed

+231
-157
lines changed

base/linalg.jl

+1
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ include("linalg/dense.jl")
198198
include("linalg/tridiag.jl")
199199
include("linalg/triangular.jl")
200200
include("linalg/factorization.jl")
201+
include("linalg/cholesky.jl")
201202
include("linalg/lu.jl")
202203

203204
include("linalg/bunchkaufman.jl")

base/linalg/cholesky.jl

+181
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
##########################
2+
# Cholesky Factorization #
3+
##########################
4+
immutable Cholesky{T,S<:AbstractMatrix{T},UpLo} <: Factorization{T}
5+
UL::S
6+
end
7+
immutable CholeskyPivoted{T} <: Factorization{T}
8+
UL::Matrix{T}
9+
uplo::Char
10+
piv::Vector{BlasInt}
11+
rank::BlasInt
12+
tol::Real
13+
info::BlasInt
14+
end
15+
16+
function chol!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U)
17+
C, info = LAPACK.potrf!(string(uplo)[1], A)
18+
return @assertposdef Triangular(C, uplo, false) info
19+
end
20+
21+
function chol!{T}(A::AbstractMatrix{T}, uplo::Symbol=:U)
22+
n = chksquare(A)
23+
@inbounds begin
24+
if uplo == :L
25+
for k = 1:n
26+
for i = 1:k - 1
27+
A[k,k] -= A[k,i]*A[k,i]'
28+
end
29+
A[k,k] = chol!(A[k,k], uplo)
30+
AkkInv = inv(A[k,k]')
31+
for j = 1:k
32+
for i = k + 1:n
33+
j == 1 && (A[i,k] = A[i,k]*AkkInv)
34+
j < k && (A[i,k] -= A[i,j]*A[k,j]'*AkkInv)
35+
end
36+
end
37+
end
38+
elseif uplo == :U
39+
for k = 1:n
40+
for i = 1:k - 1
41+
A[k,k] -= A[i,k]'A[i,k]
42+
end
43+
A[k,k] = chol!(A[k,k], uplo)
44+
AkkInv = inv(A[k,k])
45+
for j = k + 1:n
46+
for i = 1:k - 1
47+
A[k,j] -= A[i,k]'A[i,j]
48+
end
49+
A[k,j] = A[k,k]'\A[k,j]
50+
end
51+
end
52+
else
53+
throw(ArgumentError("uplo must be either :U or :L but was $(uplo)"))
54+
end
55+
end
56+
return Triangular(A, uplo, false)
57+
end
58+
59+
function cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0)
60+
uplochar = string(uplo)[1]
61+
if pivot
62+
A, piv, rank, info = LAPACK.pstrf!(uplochar, A, tol)
63+
return CholeskyPivoted{T}(A, uplochar, piv, rank, tol, info)
64+
end
65+
return Cholesky{T,typeof(A),uplo}(chol!(A, uplo).data)
66+
end
67+
cholfact!(A::AbstractMatrix, uplo::Symbol=:U) = Cholesky{eltype(A),typeof(A),uplo}(chol!(A, uplo).data)
68+
69+
cholfact{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0) = cholfact!(copy(A), uplo, pivot=pivot, tol=tol)
70+
function cholfact{T}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0)
71+
S = promote_type(typeof(chol(one(T))),Float32)
72+
S <: BlasFloat && return cholfact!(convert(AbstractMatrix{S}, A), uplo, pivot = pivot, tol = tol)
73+
pivot && throw(ArgumentError("pivot only supported for Float32, Float64, Complex{Float32} and Complex{Float64}"))
74+
S != T && return cholfact!(convert(AbstractMatrix{S}, A), uplo)
75+
return cholfact!(copy(A), uplo)
76+
end
77+
function cholfact(x::Number, uplo::Symbol=:U)
78+
xf = fill(chol!(x), 1, 1)
79+
Cholesky{:U, eltype(xf), typeof(xf)}(xf)
80+
end
81+
82+
chol(A::AbstractMatrix, uplo::Symbol=:U) = Triangular(chol!(copy(A), uplo), uplo, false)
83+
function chol!(x::Number, uplo::Symbol=:U)
84+
rx = real(x)
85+
rx == abs(x) || throw(DomainError())
86+
rxr = sqrt(rx)
87+
convert(promote_type(typeof(x), typeof(rxr)), rxr)
88+
end
89+
chol(x::Number, uplo::Symbol=:U) = chol!(x, uplo)
90+
91+
function convert{Tnew,Told,S,UpLo}(::Type{Cholesky{Tnew}},C::Cholesky{Told,S,UpLo})
92+
Cnew = convert(AbstractMatrix{Tnew}, C.UL)
93+
Cholesky{Tnew, typeof(Cnew), UpLo}(Cnew)
94+
end
95+
function convert{T,S,UpLo}(::Type{Cholesky{T,S,UpLo}},C::Cholesky)
96+
Cnew = convert(AbstractMatrix{T}, C.UL)
97+
Cholesky{T, typeof(Cnew), UpLo}(Cnew)
98+
end
99+
convert{T}(::Type{Factorization{T}}, C::Cholesky) = convert(Cholesky{T}, C)
100+
convert{T}(::Type{CholeskyPivoted{T}},C::CholeskyPivoted) = CholeskyPivoted(convert(AbstractMatrix{T},C.UL),C.uplo,C.piv,C.rank,C.tol,C.info)
101+
convert{T}(::Type{Factorization{T}}, C::CholeskyPivoted) = convert(CholeskyPivoted{T}, C)
102+
103+
full{T,S}(C::Cholesky{T,S,:U}) = C[:U]'C[:U]
104+
full{T,S}(C::Cholesky{T,S,:L}) = C[:L]*C[:L]'
105+
106+
size(C::Union(Cholesky, CholeskyPivoted)) = size(C.UL)
107+
size(C::Union(Cholesky, CholeskyPivoted), d::Integer) = size(C.UL,d)
108+
109+
function getindex{T,S,UpLo}(C::Cholesky{T,S,UpLo}, d::Symbol)
110+
d == :U && return Triangular(UpLo == d ? C.UL : C.UL',:U)
111+
d == :L && return Triangular(UpLo == d ? C.UL : C.UL',:L)
112+
d == :UL && return Triangular(C.UL, UpLo)
113+
throw(KeyError(d))
114+
end
115+
function getindex{T<:BlasFloat}(C::CholeskyPivoted{T}, d::Symbol)
116+
d == :U && return Triangular(symbol(C.uplo) == d ? C.UL : C.UL', :U)
117+
d == :L && return Triangular(symbol(C.uplo) == d ? C.UL : C.UL', :L)
118+
d == :p && return C.piv
119+
if d == :P
120+
n = size(C, 1)
121+
P = zeros(T, n, n)
122+
for i=1:n
123+
P[C.piv[i],i] = one(T)
124+
end
125+
return P
126+
end
127+
throw(KeyError(d))
128+
end
129+
130+
show{T,S<:AbstractMatrix,UpLo}(io::IO, C::Cholesky{T,S,UpLo}) = (println("$(typeof(C)) with factor:");show(io,C[UpLo]))
131+
132+
A_ldiv_B!{T<:BlasFloat,S<:AbstractMatrix}(C::Cholesky{T,S,:U}, B::StridedVecOrMat{T}) = LAPACK.potrs!('U', C.UL, B)
133+
A_ldiv_B!{T<:BlasFloat,S<:AbstractMatrix}(C::Cholesky{T,S,:L}, B::StridedVecOrMat{T}) = LAPACK.potrs!('L', C.UL, B)
134+
A_ldiv_B!{T,S<:AbstractMatrix}(C::Cholesky{T,S,:L}, B::StridedVecOrMat) = Ac_ldiv_B!(Triangular(C.UL, :L, false), A_ldiv_B!(Triangular(C.UL, :L, false), B))
135+
A_ldiv_B!{T,S<:AbstractMatrix}(C::Cholesky{T,S,:U}, B::StridedVecOrMat) = A_ldiv_B!(Triangular(C.UL, :U, false), Ac_ldiv_B!(Triangular(C.UL, :U, false), B))
136+
137+
function A_ldiv_B!{T<:BlasFloat}(C::CholeskyPivoted{T}, B::StridedVector{T})
138+
chkfullrank(C)
139+
ipermute!(LAPACK.potrs!(C.uplo, C.UL, permute!(B, C.piv)), C.piv)
140+
end
141+
function A_ldiv_B!{T<:BlasFloat}(C::CholeskyPivoted{T}, B::StridedMatrix{T})
142+
chkfullrank(C)
143+
n = size(C, 1)
144+
for i=1:size(B, 2)
145+
permute!(sub(B, 1:n, i), C.piv)
146+
end
147+
LAPACK.potrs!(C.uplo, C.UL, B)
148+
for i=1:size(B, 2)
149+
ipermute!(sub(B, 1:n, i), C.piv)
150+
end
151+
B
152+
end
153+
A_ldiv_B!(C::CholeskyPivoted, B::StridedVector) = C.uplo=='L' ? Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv]))[invperm(C.piv)] : A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv]))[invperm(C.piv)]
154+
A_ldiv_B!(C::CholeskyPivoted, B::StridedMatrix) = C.uplo=='L' ? Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv,:]))[invperm(C.piv),:] : A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv,:]))[invperm(C.piv),:]
155+
156+
function det{T,S,UpLo}(C::Cholesky{T,S,UpLo})
157+
dd = one(T)
158+
for i in 1:size(C.UL,1) dd *= abs2(C.UL[i,i]) end
159+
dd
160+
end
161+
162+
det{T}(C::CholeskyPivoted{T}) = C.rank<size(C.UL,1) ? real(zero(T)) : prod(abs2(diag(C.UL)))
163+
164+
function logdet{T,S,UpLo}(C::Cholesky{T,S,UpLo})
165+
dd = zero(T)
166+
for i in 1:size(C.UL,1) dd += log(C.UL[i,i]) end
167+
dd + dd # instead of 2.0dd which can change the type
168+
end
169+
170+
inv{T<:BlasFloat,S<:AbstractMatrix}(C::Cholesky{T,S,:U}) = copytri!(LAPACK.potri!('U', copy(C.UL)), 'U', true)
171+
inv{T<:BlasFloat,S<:AbstractMatrix}(C::Cholesky{T,S,:L}) = copytri!(LAPACK.potri!('L', copy(C.UL)), 'L', true)
172+
173+
function inv(C::CholeskyPivoted)
174+
chkfullrank(C)
175+
ipiv = invperm(C.piv)
176+
copytri!(LAPACK.potri!(C.uplo, copy(C.UL)), C.uplo, true)[ipiv, ipiv]
177+
end
178+
179+
chkfullrank(C::CholeskyPivoted) = C.rank<size(C.UL, 1) && throw(RankDeficientException(C.info))
180+
181+
rank(C::CholeskyPivoted) = C.rank

base/linalg/dense.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ vecnorm1{T<:BlasReal}(x::Union(Array{T},StridedVector{T})) =
6060
vecnorm2{T<:BlasFloat}(x::Union(Array{T},StridedVector{T})) =
6161
length(x) < NRM2_CUTOFF ? generic_vecnorm2(x) : BLAS.nrm2(x)
6262

63-
function triu!{T}(M::Matrix{T}, k::Integer)
63+
function triu!(M::AbstractMatrix, k::Integer)
6464
m, n = size(M)
6565
idx = 1
6666
for j = 0:n-1
6767
ii = min(max(0, j+1-k), m)
6868
for i = (idx+ii):(idx+m-1)
69-
M[i] = zero(T)
69+
M[i] = zero(M[i])
7070
end
7171
idx += m
7272
end
@@ -75,13 +75,13 @@ end
7575

7676
triu(M::Matrix, k::Integer) = triu!(copy(M), k)
7777

78-
function tril!{T}(M::Matrix{T}, k::Integer)
78+
function tril!(M::AbstractMatrix, k::Integer)
7979
m, n = size(M)
8080
idx = 1
8181
for j = 0:n-1
8282
ii = min(max(0, j-k), m)
8383
for i = idx:(idx+ii-1)
84-
M[i] = zero(T)
84+
M[i] = zero(M[i])
8585
end
8686
idx += m
8787
end

base/linalg/factorization.jl

-120
Original file line numberDiff line numberDiff line change
@@ -10,126 +10,6 @@ macro assertnonsingular(A, info)
1010
:(($info)==0 ? $A : throw(SingularException($info)))
1111
end
1212

13-
##########################
14-
# Cholesky Factorization #
15-
##########################
16-
immutable Cholesky{T} <: Factorization{T}
17-
UL::Matrix{T}
18-
uplo::Char
19-
end
20-
immutable CholeskyPivoted{T} <: Factorization{T}
21-
UL::Matrix{T}
22-
uplo::Char
23-
piv::Vector{BlasInt}
24-
rank::BlasInt
25-
tol::Real
26-
info::BlasInt
27-
end
28-
29-
function cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0)
30-
uplochar = string(uplo)[1]
31-
if pivot
32-
A, piv, rank, info = LAPACK.pstrf!(uplochar, A, tol)
33-
return CholeskyPivoted{T}(A, uplochar, piv, rank, tol, info)
34-
else
35-
C, info = LAPACK.potrf!(uplochar, A)
36-
return @assertposdef Cholesky(C, uplochar) info
37-
end
38-
end
39-
cholfact{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0) = cholfact!(copy(A), uplo, pivot=pivot, tol=tol)
40-
cholfact{T}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0) = (S = promote_type(typeof(sqrt(one(T))),Float32); S != T ? cholfact!(convert(AbstractMatrix{S},A), uplo, pivot=pivot, tol=tol) : cholfact!(copy(A), uplo, pivot=pivot, tol=tol)) # When julia Cholesky has been implemented, the promotion should be changed.
41-
cholfact(x::Number) = @assertposdef Cholesky(fill(sqrt(x), 1, 1), :U) !(imag(x) == 0 && real(x) > 0)
42-
43-
chol(A::Union(Number, AbstractMatrix), uplo::Symbol) = cholfact(A, uplo)[uplo]
44-
chol(A::Union(Number, AbstractMatrix)) = triu!(cholfact(A, :U).UL)
45-
46-
convert{T}(::Type{Cholesky{T}},C::Cholesky) = Cholesky(convert(AbstractMatrix{T},C.UL),C.uplo)
47-
convert{T}(::Type{Factorization{T}}, C::Cholesky) = convert(Cholesky{T}, C)
48-
convert{T}(::Type{CholeskyPivoted{T}},C::CholeskyPivoted) = CholeskyPivoted(convert(AbstractMatrix{T},C.UL),C.uplo,C.piv,C.rank,C.tol,C.info)
49-
convert{T}(::Type{Factorization{T}}, C::CholeskyPivoted) = convert(CholeskyPivoted{T}, C)
50-
51-
function full{T<:BlasFloat}(C::Cholesky{T})
52-
if C.uplo == 'U'
53-
BLAS.trmm!('R', C.uplo, 'N', 'N', one(T), C.UL, tril!(C.UL'))
54-
else
55-
BLAS.trmm!('L', C.uplo, 'N', 'N', one(T), C.UL, triu!(C.UL'))
56-
end
57-
end
58-
59-
size(C::Union(Cholesky, CholeskyPivoted)) = size(C.UL)
60-
size(C::Union(Cholesky, CholeskyPivoted), d::Integer) = size(C.UL,d)
61-
62-
function getindex(C::Cholesky, d::Symbol)
63-
d == :U && return Triangular(triu!(symbol(C.uplo) == d ? C.UL : C.UL'),:U)
64-
d == :L && return Triangular(tril!(symbol(C.uplo) == d ? C.UL : C.UL'),:L)
65-
d == :UL && return Triangular(C.UL, symbol(C.uplo))
66-
throw(KeyError(d))
67-
end
68-
function getindex{T<:BlasFloat}(C::CholeskyPivoted{T}, d::Symbol)
69-
d == :U && return triu!(symbol(C.uplo) == d ? C.UL : C.UL')
70-
d == :L && return tril!(symbol(C.uplo) == d ? C.UL : C.UL')
71-
d == :p && return C.piv
72-
if d == :P
73-
n = size(C, 1)
74-
P = zeros(T, n, n)
75-
for i=1:n
76-
P[C.piv[i],i] = one(T)
77-
end
78-
return P
79-
end
80-
throw(KeyError(d))
81-
end
82-
83-
show(io::IO, C::Cholesky) = (println(io,"$(typeof(C)) with factor:");show(io,C[symbol(C.uplo)]))
84-
85-
A_ldiv_B!{T<:BlasFloat}(C::Cholesky{T}, B::StridedVecOrMat{T}) = LAPACK.potrs!(C.uplo, C.UL, B)
86-
A_ldiv_B!(C::Cholesky, B::StridedVecOrMat) = C.uplo=='L' ? Ac_ldiv_B!(Triangular(C.UL,C.uplo,'N'), A_ldiv_B!(Triangular(C.UL,C.uplo,'N'), B)) : A_ldiv_B!(Triangular(C.UL,C.uplo,'N'), Ac_ldiv_B!(Triangular(C.UL,C.uplo,'N'), B))
87-
88-
function A_ldiv_B!{T<:BlasFloat}(C::CholeskyPivoted{T}, B::StridedVector{T})
89-
chkfullrank(C)
90-
ipermute!(LAPACK.potrs!(C.uplo, C.UL, permute!(B, C.piv)), C.piv)
91-
end
92-
function A_ldiv_B!{T<:BlasFloat}(C::CholeskyPivoted{T}, B::StridedMatrix{T})
93-
chkfullrank(C)
94-
n = size(C, 1)
95-
for i=1:size(B, 2)
96-
permute!(sub(B, 1:n, i), C.piv)
97-
end
98-
LAPACK.potrs!(C.uplo, C.UL, B)
99-
for i=1:size(B, 2)
100-
ipermute!(sub(B, 1:n, i), C.piv)
101-
end
102-
B
103-
end
104-
A_ldiv_B!(C::CholeskyPivoted, B::StridedVector) = C.uplo=='L' ? Ac_ldiv_B!(Triangular(C.UL,C.uplo,'N'), A_ldiv_B!(Triangular(C.UL,C.uplo,'N'), B[C.piv]))[invperm(C.piv)] : A_ldiv_B!(Triangular(C.UL,C.uplo,'N'), Ac_ldiv_B!(Triangular(C.UL,C.uplo,'N'), B[C.piv]))[invperm(C.piv)]
105-
A_ldiv_B!(C::CholeskyPivoted, B::StridedMatrix) = C.uplo=='L' ? Ac_ldiv_B!(Triangular(C.UL,C.uplo,'N'), A_ldiv_B!(Triangular(C.UL,C.uplo,'N'), B[C.piv,:]))[invperm(C.piv),:] : A_ldiv_B!(Triangular(C.UL,C.uplo,'N'), Ac_ldiv_B!(Triangular(C.UL,C.uplo,'N'), B[C.piv,:]))[invperm(C.piv),:]
106-
107-
function det{T}(C::Cholesky{T})
108-
dd = one(T)
109-
for i in 1:size(C.UL,1) dd *= abs2(C.UL[i,i]) end
110-
dd
111-
end
112-
113-
det{T}(C::CholeskyPivoted{T}) = C.rank<size(C.UL,1) ? real(zero(T)) : prod(abs2(diag(C.UL)))
114-
115-
function logdet{T}(C::Cholesky{T})
116-
dd = zero(T)
117-
for i in 1:size(C.UL,1) dd += log(C.UL[i,i]) end
118-
dd + dd # instead of 2.0dd which can change the type
119-
end
120-
121-
inv(C::Cholesky) = copytri!(LAPACK.potri!(C.uplo, copy(C.UL)), C.uplo, true)
122-
123-
function inv(C::CholeskyPivoted)
124-
chkfullrank(C)
125-
ipiv = invperm(C.piv)
126-
copytri!(LAPACK.potri!(C.uplo, copy(C.UL)), C.uplo, true)[ipiv, ipiv]
127-
end
128-
129-
chkfullrank(C::CholeskyPivoted) = C.rank<size(C.UL, 1) && throw(RankDeficientException(C.info))
130-
131-
rank(C::CholeskyPivoted) = C.rank
132-
13313
####################
13414
# QR Factorization #
13515
####################

base/linalg/triangular.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,11 @@ function similar{T,S,UpLo,IsUnit,Tnew}(A::Triangular{T,S,UpLo,IsUnit}, ::Type{Tn
133133
return Triangular{Tnew, typeof(A), UpLo, IsUnit}(A)
134134
end
135135

136-
getindex{T,S}(A::Triangular{T,S,:L,true}, i::Integer, j::Integer) = i == j ? one(T) : (i > j ? A.data[i,j] : zero(T))
137-
getindex{T,S}(A::Triangular{T,S,:L,false}, i::Integer, j::Integer) = i >= j ? A.data[i,j] : zero(T)
138-
getindex{T,S}(A::Triangular{T,S,:U,true}, i::Integer, j::Integer) = i == j ? one(T) : (i < j ? A.data[i,j] : zero(T))
139-
getindex{T,S}(A::Triangular{T,S,:U,false}, i::Integer, j::Integer) = i <= j ? A.data[i,j] : zero(T)
140-
getindex{T,S,UpLo,IsUnit}(A::Triangular{T,S,UpLo,IsUnit}, i::Integer) = ((m, n) = divrem(i - 1, size(A,1)); A[m + 1, n + 1])
136+
getindex{T,S}(A::Triangular{T,S,:L,true}, i::Integer, j::Integer) = i == j ? one(T) : (i > j ? A.data[i,j] : zero(A.data[i,j]))
137+
getindex{T,S}(A::Triangular{T,S,:L,false}, i::Integer, j::Integer) = i >= j ? A.data[i,j] : zero(A.data[i,j])
138+
getindex{T,S}(A::Triangular{T,S,:U,true}, i::Integer, j::Integer) = i == j ? one(T) : (i < j ? A.data[i,j] : zero(A.data[i,j]))
139+
getindex{T,S}(A::Triangular{T,S,:U,false}, i::Integer, j::Integer) = i <= j ? A.data[i,j] : zero(A.data[i,j])
140+
getindex(A::Triangular, i::Integer) = ((m, n) = divrem(i - 1, size(A,1)); A[m + 1, n + 1])
141141

142142
istril{T,S,UpLo,IsUnit}(A::Triangular{T,S,UpLo,IsUnit}) = UpLo == :L
143143
istriu{T,S,UpLo,IsUnit}(A::Triangular{T,S,UpLo,IsUnit}) = UpLo == :U

0 commit comments

Comments
 (0)