Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: factorize() and \ #3315

Merged
merged 1 commit into from
Jun 23, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ export
SVD,
GeneralizedSVD,
Hermitian,
Symmetric,
Triangular,
Diagonal,
InsertionSort,
Expand Down Expand Up @@ -566,6 +567,8 @@ export
checkbounds,

# linear algebra
bkfact,
bkfact!,
chol,
cholfact,
cholfact!,
Expand All @@ -591,6 +594,8 @@ export
expm,
sqrtm,
eye,
factorize,
factorize!,
hessfact,
hessfact!,
ishermitian,
Expand Down
6 changes: 6 additions & 0 deletions base/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ export
Schur,
SVD,
Hermitian,
Symmetric,
Triangular,
Diagonal,

# Functions
bkfact,
bkfact!,
check_openblas,
chol,
cholfact,
Expand Down Expand Up @@ -57,6 +60,8 @@ export
expm,
sqrtm,
eye,
factorize,
factorize!,
gradient,
hessfact,
hessfact!,
Expand Down Expand Up @@ -158,6 +163,7 @@ include("linalg/factorization.jl")
include("linalg/bunchkaufman.jl")
include("linalg/triangular.jl")
include("linalg/hermitian.jl")
include("linalg/symmetric.jl")
include("linalg/woodbury.jl")
include("linalg/tridiag.jl")
include("linalg/bidiag.jl")
Expand Down
6 changes: 3 additions & 3 deletions base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ promote_rule{T,S}(::Type{Tridiagonal{T}}, ::Type{Bidiagonal{S}})=Tridiagonal{pro

function show(io::IO, M::Bidiagonal)
println(io, summary(M), ":")
print(io, "diag: ")
print(io, " diag:")
print_matrix(io, (M.dv)')
print(io, M.isupper?"\n sup: ":"\n sub: ")
print(io, M.isupper?"\nsuper:":"\n sub:")
print_matrix(io, (M.ev)')
end

Expand Down Expand Up @@ -112,7 +112,7 @@ end
# solver uses tridiagonal gtsv!
function \{T<:BlasFloat}(M::Bidiagonal{T}, rhs::StridedVecOrMat{T})
if stride(rhs, 1) == 1
z = zeros(size(M)[1])
z = zeros(size(M, 1) - 1)
if M.isupper
return LAPACK.gtsv!(z, copy(M.dv), copy(M.ev), copy(rhs))
else
Expand Down
47 changes: 38 additions & 9 deletions base/linalg/bunchkaufman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,50 @@ type BunchKaufman{T<:BlasFloat} <: Factorization{T}
LD::Matrix{T}
ipiv::Vector{BlasInt}
uplo::Char
function BunchKaufman(A::Matrix{T}, uplo::Char)
LD, ipiv = LAPACK.sytrf!(uplo , copy(A))
new(LD, ipiv, uplo)
symmetric::Bool
end
function bkfact!{T<:BlasReal}(A::StridedMatrix{T}, uplo::Symbol)
LD, ipiv = LAPACK.sytrf!(string(uplo)[1] , A)
BunchKaufman(LD, ipiv, string(uplo)[1], true)
end
function bkfact!{T<:BlasReal}(A::StridedMatrix{T}, uplo::Symbol, symmetric::Bool)
if symmetric return bkfact!(A, uplo) end
error("The Bunch-Kaufman decomposition is only valid for symmetric matrices")
end
function bkfact!{T<:BlasComplex}(A::StridedMatrix{T}, uplo::Symbol, symmetric::Bool)
if symmetric
LD, ipiv = LAPACK.sytrf!(string(uplo)[1] , A)
else
LD, ipiv = LAPACK.hetrf!(string(uplo)[1] , A)
end
BunchKaufman(LD, ipiv, string(uplo)[1], symmetric)
end
BunchKaufman{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Char) = BunchKaufman{T}(A, uplo)
BunchKaufman{T<:Real}(A::StridedMatrix{T}, uplo::Char) = BunchKaufman(float64(A), uplo)
BunchKaufman{T<:Number}(A::StridedMatrix{T}) = BunchKaufman(A, 'U')
bkfact!{T<:BlasComplex}(A::StridedMatrix{T}, uplo::Symbol) = bkfact!(A, uplo, issym(A))
bkfact!(A::StridedMatrix, args...) = bkfact!(float(A), args...)
bkfact!{T<:BlasFloat}(A::StridedMatrix{T}) = bkfact!(A, :U)
bkfact{T<:BlasFloat}(A::StridedMatrix{T}, args...) = bkfact!(copy(A), args...)
bkfact(A::StridedMatrix, args...) = bkfact!(float(A), args...)

size(B::BunchKaufman) = size(B.LD)
size(B::BunchKaufman,d::Integer) = size(B.LD,d)
issym(B::BunchKaufman) = B.symmetric
ishermitian(B::BunchKaufman) = !B.symmetric

function inv(B::BunchKaufman)
function inv{T<:BlasReal}(B::BunchKaufman{T})
symmetrize_conj!(LAPACK.sytri!(B.uplo, copy(B.LD), B.ipiv), B.uplo)
end
function inv{T<:BlasComplex}(B::BunchKaufman{T})
if issym(B)
symmetrize!(LAPACK.sytri!(B.uplo, copy(B.LD), B.ipiv), B.uplo)
else
symmetrize_conj!(LAPACK.hetri!(B.uplo, copy(B.LD), B.ipiv), B.uplo)
end
end

\{T<:BlasFloat}(B::BunchKaufman{T}, R::StridedVecOrMat{T}) =
LAPACK.sytrs!(B.uplo, B.LD, B.ipiv, copy(R))
\{T<:BlasReal}(B::BunchKaufman{T}, R::StridedVecOrMat{T}) = LAPACK.sytrs!(B.uplo, B.LD, B.ipiv, copy(R))
function \{T<:BlasComplex}(B::BunchKaufman{T}, R::StridedVecOrMat{T})
if issym(B)
return LAPACK.sytrs!(B.uplo, B.LD, B.ipiv, copy(R))
end
return LAPACK.hetrs!(B.uplo, B.LD, B.ipiv, copy(R))
end
142 changes: 120 additions & 22 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,13 @@ expm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T}) = expm
expm{T<:Integer}(A::StridedMatrix{T}) = expm!(float(A))
expm(x::Number) = exp(x)

function sqrtm(A::StridedMatrix, cond::Bool)
function sqrtm{T<:Real}(A::StridedMatrix{T}, cond::Bool)
m, n = size(A)
if m != n error("DimentionMismatch") end
if ishermitian(A)
return sqrtm(Hermitian(A), cond)
if issym(A)
return sqrtm(Symmetric(A), cond)
else
SchurF = schurfact!(iseltype(A,Complex) ? copy(A) : complex(A))
SchurF = schurfact!(complex(A))
R = zeros(eltype(SchurF[:T]), n, n)
for j = 1:n
R[j,j] = sqrt(SchurF[:T][j,j])
Expand All @@ -375,6 +375,35 @@ function sqrtm(A::StridedMatrix, cond::Bool)
return (all(imag(retmat) .== 0) ? real(retmat) : retmat)
end
end
function sqrtm{T<:Complex}(A::StridedMatrix{T}, cond::Bool)
m, n = size(A)
if m != n error("DimentionMismatch") end
if ishermitian(A)
return sqrtm(Hermitian(A), cond)
else
SchurF = schurfact(A)
R = zeros(eltype(SchurF[:T]), n, n)
for j = 1:n
R[j,j] = sqrt(SchurF[:T][j,j])
for i = j - 1:-1:1
r = SchurF[:T][i,j]
for k = i + 1:j - 1
r -= R[i,k]*R[k,j]
end
if r != 0
R[i,j] = r / (R[i,i] + R[j,j])
end
end
end
end
retmat = SchurF[:vectors]*R*SchurF[:vectors]'
if cond
alpha = norm(R)^2/norm(SchurF[:T])
return retmat, alpha
else
return retmat
end
end

sqrtm{T<:Integer}(A::StridedMatrix{T}, cond::Bool) = sqrtm(float(A), cond)
sqrtm{T<:Integer}(A::StridedMatrix{Complex{T}}, cond::Bool) = sqrtm(complex128(A), cond)
Expand All @@ -383,43 +412,112 @@ sqrtm(a::Number) = (b = sqrt(complex(a)); imag(b) == 0 ? real(b) : b)
sqrtm(a::Complex) = sqrt(a)

function det(A::Matrix)
m, n = size(A)
if m != n; throw(DimensionMismatch("det only defined for square matrices")); end
if istriu(A) | istril(A); return det(Triangular(A, :U, false)); end
return det(lufact(A))
end
det(x::Number) = x

logdet(A::Matrix) = logdet(cholfact(A))
logdet(A::Matrix) = logdet(lufact(A))

function inv{T<:BlasFloat}(A::AbstractMatrix{T})
if istriu(A) return inv(Triangular(A, :U)) end
if istril(A) return inv(Triangular(A, :L)) end
if ishermitian(A) return inv(Hermitian(A)) end
function inv(A::AbstractMatrix)
if istriu(A) | istril(A); return inv(Triangular(A, :U, false)); end
return inv(lufact(A))
end
inv(A::AbstractMatrix) = inv(float(A))

function (\){T<:BlasFloat}(A::StridedMatrix{T}, B::StridedVecOrMat{T})
if size(A, 1) == size(A, 2) # Square
if istriu(A) return Triangular(A, :U)\B end
if istril(A) return Triangular(A, :L)\B end
if ishermitian(A) return Hermitian(A)\B end
ans, _, _, info = LAPACK.gesv!(copy(A), copy(B))
if info > 0; throw(SingularException(info)); end
return ans
else
LAPACK.gelsy!(copy(A), copy(B))[1]
function factorize!{T}(A::Matrix{T})
m, n = size(A)
if m == n
if m == 1 return A[1] end
utri = true
utri1 = true
herm = T <: Complex
sym = true
for j = 1:n-1, i = j+1:m
if utri1
if A[i,j] != 0
utri1 = i == j + 1
utri = false
end
end
if sym
sym &= A[i,j] == A[j,i]
end
if (T <: Complex) & herm
herm &= A[i,j] == conj(A[j,i])
end
if !(utri1|herm|sym) break end
end
ltri = true
ltri1 = true
for j = 3:n, i = 1:j-2
ltri1 &= A[i,j] == 0
if !ltri1 break end
end
if ltri1
for i = 1:n-1
if A[i,i+1] != 0
ltri &= false
break
end
end
if ltri
if utri
return Diagonal(A)
end
if utri1
return lufact!(Bidiagonal(diag(A), diag(A, -1), false))
end
return Triangular(A, :L)
end
if utri
return lufact!(Bidiagonal(diag(A), diag(A, 1), true))
end
if utri1
if (herm & (T <: Complex)) | sym
return ldltd!(SymTridiagonal(diag(A), diag(A, -1)))
end
return lufact!(Tridiagonal(diag(A, -1), diag(A), diag(A, 1)))
end
end
if utri
return Triangular(A, :U)
end
if herm
C, info = LAPACK.potrf!('U', copy(A))
if info == 0 return Cholesky(C, 'U') end
return factorize!(Hermitian(A))
end
if sym
C, info = LAPACK.potrf!('U', copy(A))
if info == 0 return Cholesky(C, 'U') end
return factorize!(Symmetric(A))
end
return lufact!(A)
end
return qrpfact!(A)
end

factorize(A::AbstractMatrix) = factorize!(copy(A))

(\){T1<:BlasFloat, T2<:BlasFloat}(A::StridedMatrix{T1}, B::StridedVecOrMat{T2}) =
(\)(convert(Array{promote_type(T1,T2)},A), convert(Array{promote_type(T1,T2)},B))
(\){T1<:BlasFloat, T2<:Real}(A::StridedMatrix{T1}, B::StridedVecOrMat{T2}) = (\)(A, convert(Array{T1}, B))
(\){T1<:Real, T2<:BlasFloat}(A::StridedMatrix{T1}, B::StridedVecOrMat{T2}) = (\)(convert(Array{T2}, A), B)
(\){T1<:Real, T2<:Real}(A::StridedMatrix{T1}, B::StridedVecOrMat{T2}) = (\)(float64(A), float64(B))
(\){T1<:Number, T2<:Number}(A::StridedMatrix{T1}, B::StridedVecOrMat{T2}) = (\)(complex128(A), complex128(B))
(\)(a::Vector, B::StridedVecOrMat) = (\)(reshape(a, length(a), 1), B)
function (\){T<:BlasFloat}(A::StridedMatrix{T}, B::StridedVecOrMat{T})
m, n = size(A)
if m == n
if istril(A)
if istriu(A) return \(Diagonal(A),B) end
return \(Triangular(A, :L),B)
end
if istriu(A) return \(Triangular(A, :U),B) end
return \(lufact(A),B)
end
return qrpfact(A)\B
end

## Moore-Penrose inverse
function pinv{T<:BlasFloat}(A::StridedMatrix{T})
Expand Down
1 change: 1 addition & 0 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
type Diagonal{T} <: AbstractMatrix{T}
diag::Vector{T}
end
Diagonal(A::Matrix) = Diagonal(diag(A))

size(D::Diagonal) = (length(D.diag),length(D.diag))
size(D::Diagonal,d::Integer) = d<1 ? error("dimension out of range") : (d<=2 ? length(D.diag) : 1)
Expand Down
Loading