|
| 1 | +########################## |
| 2 | +# Cholesky Factorization # |
| 3 | +########################## |
| 4 | +immutable Cholesky{T,S<:AbstractMatrix{T},UpLo} <: Factorization{T} |
| 5 | + UL::S |
| 6 | +end |
| 7 | +immutable CholeskyPivoted{T} <: Factorization{T} |
| 8 | + UL::Matrix{T} |
| 9 | + uplo::Char |
| 10 | + piv::Vector{BlasInt} |
| 11 | + rank::BlasInt |
| 12 | + tol::Real |
| 13 | + info::BlasInt |
| 14 | +end |
| 15 | + |
| 16 | +function chol!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U) |
| 17 | + C, info = LAPACK.potrf!(string(uplo)[1], A) |
| 18 | + return @assertposdef Triangular(C, uplo, false) info |
| 19 | +end |
| 20 | + |
| 21 | +function chol!{T}(A::AbstractMatrix{T}, uplo::Symbol=:U) |
| 22 | + n = chksquare(A) |
| 23 | + @inbounds begin |
| 24 | + if uplo == :L |
| 25 | + for k = 1:n |
| 26 | + for i = 1:k - 1 |
| 27 | + A[k,k] -= A[k,i]*A[k,i]' |
| 28 | + end |
| 29 | + A[k,k] = chol!(A[k,k], uplo) |
| 30 | + AkkInv = inv(A[k,k]') |
| 31 | + for j = 1:k |
| 32 | + for i = k + 1:n |
| 33 | + j == 1 && (A[i,k] = A[i,k]*AkkInv) |
| 34 | + j < k && (A[i,k] -= A[i,j]*A[k,j]'*AkkInv) |
| 35 | + end |
| 36 | + end |
| 37 | + end |
| 38 | + elseif uplo == :U |
| 39 | + for k = 1:n |
| 40 | + for i = 1:k - 1 |
| 41 | + A[k,k] -= A[i,k]'A[i,k] |
| 42 | + end |
| 43 | + A[k,k] = chol!(A[k,k], uplo) |
| 44 | + AkkInv = inv(A[k,k]) |
| 45 | + for j = k + 1:n |
| 46 | + for i = 1:k - 1 |
| 47 | + A[k,j] -= A[i,k]'A[i,j] |
| 48 | + end |
| 49 | + A[k,j] = A[k,k]'\A[k,j] |
| 50 | + end |
| 51 | + end |
| 52 | + else |
| 53 | + throw(ArgumentError("uplo must be either :U or :L but was $(uplo)")) |
| 54 | + end |
| 55 | + end |
| 56 | + return Triangular(A, uplo, false) |
| 57 | +end |
| 58 | + |
| 59 | +function cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0) |
| 60 | + uplochar = string(uplo)[1] |
| 61 | + if pivot |
| 62 | + A, piv, rank, info = LAPACK.pstrf!(uplochar, A, tol) |
| 63 | + return CholeskyPivoted{T}(A, uplochar, piv, rank, tol, info) |
| 64 | + end |
| 65 | + return Cholesky{T,typeof(A),uplo}(chol!(A, uplo).data) |
| 66 | +end |
| 67 | +cholfact!(A::AbstractMatrix, uplo::Symbol=:U) = Cholesky{eltype(A),typeof(A),uplo}(chol!(A, uplo).data) |
| 68 | + |
| 69 | +cholfact{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0) = cholfact!(copy(A), uplo, pivot=pivot, tol=tol) |
| 70 | +function cholfact{T}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0) |
| 71 | + S = promote_type(typeof(chol(one(T))),Float32) |
| 72 | + S <: BlasFloat && return cholfact!(convert(AbstractMatrix{S}, A), uplo, pivot = pivot, tol = tol) |
| 73 | + pivot && throw(ArgumentError("pivot only supported for Float32, Float64, Complex{Float32} and Complex{Float64}")) |
| 74 | + S != T && return cholfact!(convert(AbstractMatrix{S}, A), uplo) |
| 75 | + return cholfact!(copy(A), uplo) |
| 76 | +end |
| 77 | +function cholfact(x::Number, uplo::Symbol=:U) |
| 78 | + xf = fill(chol!(x), 1, 1) |
| 79 | + Cholesky{:U, eltype(xf), typeof(xf)}(xf) |
| 80 | +end |
| 81 | + |
| 82 | +chol(A::AbstractMatrix, uplo::Symbol=:U) = Triangular(chol!(copy(A), uplo), uplo, false) |
| 83 | +function chol!(x::Number, uplo::Symbol=:U) |
| 84 | + rx = real(x) |
| 85 | + rx == abs(x) || throw(DomainError()) |
| 86 | + rxr = sqrt(rx) |
| 87 | + convert(promote_type(typeof(x), typeof(rxr)), rxr) |
| 88 | +end |
| 89 | +chol(x::Number, uplo::Symbol=:U) = chol!(x, uplo) |
| 90 | + |
| 91 | +function convert{Tnew,Told,S,UpLo}(::Type{Cholesky{Tnew}},C::Cholesky{Told,S,UpLo}) |
| 92 | + Cnew = convert(AbstractMatrix{Tnew}, C.UL) |
| 93 | + Cholesky{Tnew, typeof(Cnew), UpLo}(Cnew) |
| 94 | +end |
| 95 | +function convert{T,S,UpLo}(::Type{Cholesky{T,S,UpLo}},C::Cholesky) |
| 96 | + Cnew = convert(AbstractMatrix{T}, C.UL) |
| 97 | + Cholesky{T, typeof(Cnew), UpLo}(Cnew) |
| 98 | +end |
| 99 | +convert{T}(::Type{Factorization{T}}, C::Cholesky) = convert(Cholesky{T}, C) |
| 100 | +convert{T}(::Type{CholeskyPivoted{T}},C::CholeskyPivoted) = CholeskyPivoted(convert(AbstractMatrix{T},C.UL),C.uplo,C.piv,C.rank,C.tol,C.info) |
| 101 | +convert{T}(::Type{Factorization{T}}, C::CholeskyPivoted) = convert(CholeskyPivoted{T}, C) |
| 102 | + |
| 103 | +full{T,S}(C::Cholesky{T,S,:U}) = C[:U]'C[:U] |
| 104 | +full{T,S}(C::Cholesky{T,S,:L}) = C[:L]*C[:L]' |
| 105 | + |
| 106 | +size(C::Union(Cholesky, CholeskyPivoted)) = size(C.UL) |
| 107 | +size(C::Union(Cholesky, CholeskyPivoted), d::Integer) = size(C.UL,d) |
| 108 | + |
| 109 | +function getindex{T,S,UpLo}(C::Cholesky{T,S,UpLo}, d::Symbol) |
| 110 | + d == :U && return Triangular(UpLo == d ? C.UL : C.UL',:U) |
| 111 | + d == :L && return Triangular(UpLo == d ? C.UL : C.UL',:L) |
| 112 | + d == :UL && return Triangular(C.UL, UpLo) |
| 113 | + throw(KeyError(d)) |
| 114 | +end |
| 115 | +function getindex{T<:BlasFloat}(C::CholeskyPivoted{T}, d::Symbol) |
| 116 | + d == :U && return Triangular(symbol(C.uplo) == d ? C.UL : C.UL', :U) |
| 117 | + d == :L && return Triangular(symbol(C.uplo) == d ? C.UL : C.UL', :L) |
| 118 | + d == :p && return C.piv |
| 119 | + if d == :P |
| 120 | + n = size(C, 1) |
| 121 | + P = zeros(T, n, n) |
| 122 | + for i=1:n |
| 123 | + P[C.piv[i],i] = one(T) |
| 124 | + end |
| 125 | + return P |
| 126 | + end |
| 127 | + throw(KeyError(d)) |
| 128 | +end |
| 129 | + |
| 130 | +show{T,S<:AbstractMatrix,UpLo}(io::IO, C::Cholesky{T,S,UpLo}) = (println("$(typeof(C)) with factor:");show(io,C[UpLo])) |
| 131 | + |
| 132 | +A_ldiv_B!{T<:BlasFloat,S<:AbstractMatrix}(C::Cholesky{T,S,:U}, B::StridedVecOrMat{T}) = LAPACK.potrs!('U', C.UL, B) |
| 133 | +A_ldiv_B!{T<:BlasFloat,S<:AbstractMatrix}(C::Cholesky{T,S,:L}, B::StridedVecOrMat{T}) = LAPACK.potrs!('L', C.UL, B) |
| 134 | +A_ldiv_B!{T,S<:AbstractMatrix}(C::Cholesky{T,S,:L}, B::StridedVecOrMat) = Ac_ldiv_B!(Triangular(C.UL, :L, false), A_ldiv_B!(Triangular(C.UL, :L, false), B)) |
| 135 | +A_ldiv_B!{T,S<:AbstractMatrix}(C::Cholesky{T,S,:U}, B::StridedVecOrMat) = A_ldiv_B!(Triangular(C.UL, :U, false), Ac_ldiv_B!(Triangular(C.UL, :U, false), B)) |
| 136 | + |
| 137 | +function A_ldiv_B!{T<:BlasFloat}(C::CholeskyPivoted{T}, B::StridedVector{T}) |
| 138 | + chkfullrank(C) |
| 139 | + ipermute!(LAPACK.potrs!(C.uplo, C.UL, permute!(B, C.piv)), C.piv) |
| 140 | +end |
| 141 | +function A_ldiv_B!{T<:BlasFloat}(C::CholeskyPivoted{T}, B::StridedMatrix{T}) |
| 142 | + chkfullrank(C) |
| 143 | + n = size(C, 1) |
| 144 | + for i=1:size(B, 2) |
| 145 | + permute!(sub(B, 1:n, i), C.piv) |
| 146 | + end |
| 147 | + LAPACK.potrs!(C.uplo, C.UL, B) |
| 148 | + for i=1:size(B, 2) |
| 149 | + ipermute!(sub(B, 1:n, i), C.piv) |
| 150 | + end |
| 151 | + B |
| 152 | +end |
| 153 | +A_ldiv_B!(C::CholeskyPivoted, B::StridedVector) = C.uplo=='L' ? Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv]))[invperm(C.piv)] : A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv]))[invperm(C.piv)] |
| 154 | +A_ldiv_B!(C::CholeskyPivoted, B::StridedMatrix) = C.uplo=='L' ? Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv,:]))[invperm(C.piv),:] : A_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), Ac_ldiv_B!(Triangular(C.UL, symbol(C.uplo), false), B[C.piv,:]))[invperm(C.piv),:] |
| 155 | + |
| 156 | +function det{T,S,UpLo}(C::Cholesky{T,S,UpLo}) |
| 157 | + dd = one(T) |
| 158 | + for i in 1:size(C.UL,1) dd *= abs2(C.UL[i,i]) end |
| 159 | + dd |
| 160 | +end |
| 161 | + |
| 162 | +det{T}(C::CholeskyPivoted{T}) = C.rank<size(C.UL,1) ? real(zero(T)) : prod(abs2(diag(C.UL))) |
| 163 | + |
| 164 | +function logdet{T,S,UpLo}(C::Cholesky{T,S,UpLo}) |
| 165 | + dd = zero(T) |
| 166 | + for i in 1:size(C.UL,1) dd += log(C.UL[i,i]) end |
| 167 | + dd + dd # instead of 2.0dd which can change the type |
| 168 | +end |
| 169 | + |
| 170 | +inv{T<:BlasFloat,S<:AbstractMatrix}(C::Cholesky{T,S,:U}) = copytri!(LAPACK.potri!('U', copy(C.UL)), 'U', true) |
| 171 | +inv{T<:BlasFloat,S<:AbstractMatrix}(C::Cholesky{T,S,:L}) = copytri!(LAPACK.potri!('L', copy(C.UL)), 'L', true) |
| 172 | + |
| 173 | +function inv(C::CholeskyPivoted) |
| 174 | + chkfullrank(C) |
| 175 | + ipiv = invperm(C.piv) |
| 176 | + copytri!(LAPACK.potri!(C.uplo, copy(C.UL)), C.uplo, true)[ipiv, ipiv] |
| 177 | +end |
| 178 | + |
| 179 | +chkfullrank(C::CholeskyPivoted) = C.rank<size(C.UL, 1) && throw(RankDeficientException(C.info)) |
| 180 | + |
| 181 | +rank(C::CholeskyPivoted) = C.rank |
0 commit comments