Skip to content

cosmetic fixes for LinAlg.chksquare #14601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 9, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions base/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,18 @@ function chkstride1(A...)
end
end

# Check that matrix is square
function chksquare(A)
"""
LinAlg.checksquare(A)

Check that a matrix is square, then return its common dimension. For multiple arguments, return a vector.
"""
function checksquare(A)
m,n = size(A)
m == n || throw(DimensionMismatch("matrix is not square"))
m
end

function chksquare(A...)
function checksquare(A...)
sizes = Int[]
for a in A
size(a,1)==size(a,2) || throw(DimensionMismatch("matrix is not square: dimensions are $(size(a))"))
Expand Down
2 changes: 1 addition & 1 deletion base/linalg/arnoldi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function _eigs(A, B;
tol=0.0, maxiter::Integer=300, sigma=nothing, v0::Vector=zeros(eltype(A),(0,)),
ritzvec::Bool=true)

n = chksquare(A)
n = checksquare(A)

T = eltype(A)
iscmplx = T <: Complex
Expand Down
26 changes: 13 additions & 13 deletions base/linalg/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export

const libblas = Base.libblas_name

import ..LinAlg: BlasReal, BlasComplex, BlasFloat, BlasInt, DimensionMismatch, chksquare, axpy!
import ..LinAlg: BlasReal, BlasComplex, BlasFloat, BlasInt, DimensionMismatch, checksquare, axpy!

"""
blas_set_num_threads(n)
Expand Down Expand Up @@ -682,7 +682,7 @@ for (fname, elty) in ((:dtrmv_,:Float64),
# * .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*)
function trmv!(uplo::Char, trans::Char, diag::Char, A::StridedMatrix{$elty}, x::StridedVector{$elty})
n = chksquare(A)
n = checksquare(A)
if n != length(x)
throw(DimensionMismatch("A has size ($n,$n), x has length $(length(x))"))
end
Expand Down Expand Up @@ -731,7 +731,7 @@ for (fname, elty) in ((:dtrsv_,:Float64),
# .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*)
function trsv!(uplo::Char, trans::Char, diag::Char, A::StridedMatrix{$elty}, x::StridedVector{$elty})
n = chksquare(A)
n = checksquare(A)
if n != length(x)
throw(DimensionMismatch("size of A is $n != length(x) = $(length(x))"))
end
Expand Down Expand Up @@ -795,7 +795,7 @@ for (fname, elty) in ((:dsyr_,:Float64),
(:csyr_,:Complex64))
@eval begin
function syr!(uplo::Char, α::$elty, x::StridedVector{$elty}, A::StridedMatrix{$elty})
n = chksquare(A)
n = checksquare(A)
if length(x) != n
throw(DimensionMismatch("A has size ($n,$n), x has length $(length(x))"))
end
Expand Down Expand Up @@ -824,7 +824,7 @@ for (fname, elty, relty) in ((:zher_,:Complex128, :Float64),
(:cher_,:Complex64, :Float32))
@eval begin
function her!(uplo::Char, α::$relty, x::StridedVector{$elty}, A::StridedMatrix{$elty})
n = chksquare(A)
n = checksquare(A)
if length(x) != n
throw(DimensionMismatch("A has size ($n,$n), x has length $(length(x))"))
end
Expand Down Expand Up @@ -922,7 +922,7 @@ for (mfname, elty) in ((:dsymm_,:Float64),
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function symm!(side::Char, uplo::Char, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
j = chksquare(A)
j = checksquare(A)
if j != (side == 'L' ? m : n)
throw(DimensionMismatch("A has size $(size(A)), C has size ($m,$n)"))
end
Expand Down Expand Up @@ -991,7 +991,7 @@ for (mfname, elty) in ((:zhemm_,:Complex128),
# DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
function hemm!(side::Char, uplo::Char, alpha::($elty), A::StridedMatrix{$elty}, B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
m, n = size(C)
j = chksquare(A)
j = checksquare(A)
if j != (side == 'L' ? m : n)
throw(DimensionMismatch("A has size $(size(A)), C has size ($m,$n)"))
end
Expand Down Expand Up @@ -1050,7 +1050,7 @@ for (fname, elty) in ((:dsyrk_,:Float64),
function syrk!(uplo::Char, trans::Char,
alpha::($elty), A::StridedVecOrMat{$elty},
beta::($elty), C::StridedMatrix{$elty})
n = chksquare(C)
n = checksquare(C)
nn = size(A, trans == 'N' ? 1 : 2)
if nn != n throw(DimensionMismatch("C has size ($n,$n), corresponding dimension of A is $nn")) end
k = size(A, trans == 'N' ? 2 : 1)
Expand Down Expand Up @@ -1103,7 +1103,7 @@ for (fname, elty, relty) in ((:zherk_, :Complex128, :Float64),
# COMPLEX A(LDA,*),C(LDC,*)
function herk!(uplo::Char, trans::Char, α::$relty, A::StridedVecOrMat{$elty},
β::$relty, C::StridedMatrix{$elty})
n = chksquare(C)
n = checksquare(C)
nn = size(A, trans == 'N' ? 1 : 2)
if nn != n
throw(DimensionMismatch("the matrix to update has dimension $n but the implied dimension of the update is $(size(A, trans == 'N' ? 1 : 2))"))
Expand Down Expand Up @@ -1144,7 +1144,7 @@ for (fname, elty) in ((:dsyr2k_,:Float64),
function syr2k!(uplo::Char, trans::Char,
alpha::($elty), A::StridedVecOrMat{$elty}, B::StridedVecOrMat{$elty},
beta::($elty), C::StridedMatrix{$elty})
n = chksquare(C)
n = checksquare(C)
nn = size(A, trans == 'N' ? 1 : 2)
if nn != n throw(DimensionMismatch("C has size ($n,$n), corresponding dimension of A is $nn")) end
k = size(A, trans == 'N' ? 2 : 1)
Expand Down Expand Up @@ -1181,7 +1181,7 @@ for (fname, elty1, elty2) in ((:zher2k_,:Complex128,:Float64), (:cher2k_,:Comple
function her2k!(uplo::Char, trans::Char, alpha::($elty1),
A::StridedVecOrMat{$elty1}, B::StridedVecOrMat{$elty1},
beta::($elty2), C::StridedMatrix{$elty1})
n = chksquare(C)
n = checksquare(C)
nn = size(A, trans == 'N' ? 1 : 2)
if nn != n throw(DimensionMismatch("C has size ($n,$n), corresponding dimension of A is $nn")) end
k = size(A, trans == 'N' ? 2 : 1)
Expand Down Expand Up @@ -1258,7 +1258,7 @@ for (mmname, smname, elty) in
function trmm!(side::Char, uplo::Char, transa::Char, diag::Char, alpha::Number,
A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
m, n = size(B)
nA = chksquare(A)
nA = checksquare(A)
if nA != (side == 'L' ? m : n)
throw(DimensionMismatch("size of A, $(size(A)), doesn't match $side size of B with dims, $(size(B))"))
end
Expand All @@ -1283,7 +1283,7 @@ for (mmname, smname, elty) in
function trsm!(side::Char, uplo::Char, transa::Char, diag::Char,
alpha::$elty, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
m, n = size(B)
k = chksquare(A)
k = checksquare(A)
if k != (side == 'L' ? m : n)
throw(DimensionMismatch("size of A is $n, size(B)=($m,$n) and transa='$transa'"))
end
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end
chol!(A::StridedMatrix) = chol!(A, UpperTriangular)

function chol!{T}(A::AbstractMatrix{T}, ::Type{UpperTriangular})
n = chksquare(A)
n = checksquare(A)
@inbounds begin
for k = 1:n
for i = 1:k - 1
Expand All @@ -54,7 +54,7 @@ function chol!{T}(A::AbstractMatrix{T}, ::Type{UpperTriangular})
return UpperTriangular(A)
end
function chol!{T}(A::AbstractMatrix{T}, ::Type{LowerTriangular})
n = chksquare(A)
n = checksquare(A)
@inbounds begin
for k = 1:n
for i = 1:k - 1
Expand Down
14 changes: 7 additions & 7 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ end
diagm(x::Number) = (X = Array(typeof(x),1,1); X[1,1] = x; X)

function trace{T}(A::Matrix{T})
n = chksquare(A)
n = checksquare(A)
t = zero(T)
for i=1:n
t += A[i,i]
Expand Down Expand Up @@ -175,7 +175,7 @@ function ^(A::Matrix, p::Number)
if isinteger(p)
return A^Integer(real(p))
end
chksquare(A)
checksquare(A)
v, X = eig(A)
any(v.<0) && (v = complex(v))
Xinv = ishermitian(A) ? X' : inv(X)
Expand All @@ -190,7 +190,7 @@ expm(x::Number) = exp(x)
## Destructive matrix exponential using algorithm from Higham, 2008,
## "Functions of Matrices: Theory and Computation", SIAM
function expm!{T<:BlasFloat}(A::StridedMatrix{T})
n = chksquare(A)
n = checksquare(A)
if ishermitian(A)
return full(expm(Hermitian(A)))
end
Expand Down Expand Up @@ -308,7 +308,7 @@ function logm(A::StridedMatrix)
end

# Use Schur decomposition
n = chksquare(A)
n = checksquare(A)
if istriu(A)
retmat = full(logm(UpperTriangular(complex(A))))
d = diag(A)
Expand Down Expand Up @@ -343,7 +343,7 @@ function sqrtm{T<:Real}(A::StridedMatrix{T})
if issym(A)
return full(sqrtm(Symmetric(A)))
end
n = chksquare(A)
n = checksquare(A)
if istriu(A)
return full(sqrtm(UpperTriangular(A)))
else
Expand All @@ -356,7 +356,7 @@ function sqrtm{T<:Complex}(A::StridedMatrix{T})
if ishermitian(A)
return full(sqrtm(Hermitian(A)))
end
n = chksquare(A)
n = checksquare(A)
if istriu(A)
return full(sqrtm(UpperTriangular(A)))
else
Expand Down Expand Up @@ -511,7 +511,7 @@ function cond(A::AbstractMatrix, p::Real=2)
maxv = maximum(v)
return maxv == 0.0 ? oftype(real(A[1,1]),Inf) : maxv / minimum(v)
elseif p == 1 || p == Inf
chksquare(A)
checksquare(A)
return cond(lufact(A), p)
end
throw(ArgumentError("p-norm must be 1, 2 or Inf, got $p"))
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ end
rank(x::Number) = x==0 ? 0 : 1

function trace(A::AbstractMatrix)
chksquare(A)
checksquare(A)
sum(diag(A))
end
trace(x::Number) = x
Expand All @@ -347,7 +347,7 @@ trace(x::Number) = x
inv(a::StridedMatrix) = throw(ArgumentError("argument must be a square matrix"))
function inv{T}(A::AbstractMatrix{T})
S = typeof(zero(T)/one(T))
A_ldiv_B!(factorize(convert(AbstractMatrix{S}, A)), eye(S, chksquare(A)))
A_ldiv_B!(factorize(convert(AbstractMatrix{S}, A)), eye(S, checksquare(A)))
end

function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
Expand Down
Loading