Skip to content

Commit 74c5b11

Browse files
committed
Add qrfact(SparseMatrixCSC) by wrapping SPQR. Make \(SparseMatrixCSC) work for least squares problems. Some tests.
1 parent 3451d8f commit 74c5b11

File tree

5 files changed

+21
-12
lines changed

5 files changed

+21
-12
lines changed

base/linalg/generic.jl

+6-2
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,15 @@ function inv{T}(A::AbstractMatrix{T})
237237
A_ldiv_B!(factorize(convert(AbstractMatrix{S}, A)), eye(S, chksquare(A)))
238238
end
239239

240+
function \{T}(A::AbstractMatrix{T}, B::AbstractVecOrMat{T})
241+
size(A,1) == size(B,1) || throw(DimensionMismatch("LHS and RHS should have the same number of rows. LHS has $(size(A,1)) rows, but RHS has $(size(B,1)) rows."))
242+
factorize(A)\B
243+
end
240244
function \{TA,TB}(A::AbstractMatrix{TA}, B::AbstractVecOrMat{TB})
241245
TC = typeof(one(TA)/one(TB))
242-
size(A,1) == size(B,1) || throw(DimensionMismatch("LHS and RHS should have the same number of rows. LHS has $(size(A,1)) rows, but RHS has $(size(B,1)) rows."))
243-
\(factorize(TA == TC ? A : convert(AbstractMatrix{TC}, A)), TB == TC ? copy(B) : convert(AbstractArray{TC}, B))
246+
convert(AbstractMatrix{TC}, A)\convert(AbstractArray{TC}, B)
244247
end
248+
245249
\(a::AbstractVector, b::AbstractArray) = reshape(a, length(a), 1) \ b
246250
/(A::AbstractVecOrMat, B::AbstractVecOrMat) = (B' \ A')'
247251
# \(A::StridedMatrix,x::Number) = inv(A)*x Should be added at some point when the old elementwise version has been deprecated long enough

base/sparse.jl

+1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ include("sparse/csparse.jl")
1919
include("sparse/linalg.jl")
2020
include("sparse/umfpack.jl")
2121
include("sparse/cholmod.jl")
22+
include("sparse/spqr.jl")
2223

2324
end # module SparseMatrix

base/sparse/cholmod.jl

+10-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export
1313
Factor,
1414
Sparse
1515

16-
using Base.SparseMatrix: AbstractSparseMatrix, SparseMatrixCSC, increment, increment!, indtype, decrement, decrement!
16+
using Base.SparseMatrix: AbstractSparseMatrix, SparseMatrixCSC, increment, indtype
1717

1818
#########
1919
# Setup #
@@ -716,7 +716,7 @@ function Sparse(filename::ASCIIString)
716716
end
717717

718718
## convertion back to base Julia types
719-
function convert{T<:VTypes}(::Type{Matrix}, D::Dense{T})
719+
function convert{T}(::Type{Matrix{T}}, D::Dense{T})
720720
s = unsafe_load(D.p)
721721
a = Array(T, s.nrow, s.ncol)
722722
if s.d == s.nrow
@@ -730,6 +730,14 @@ function convert{T<:VTypes}(::Type{Matrix}, D::Dense{T})
730730
end
731731
a
732732
end
733+
convert{T}(::Type{Matrix}, D::Dense{T}) = convert(Matrix{T}, D)
734+
function convert{T}(::Type{Vector{T}}, D::Dense{T})
735+
if size(D, 2) > 1
736+
throw(DimensionMismatch("input must be a vector but had $(size(D, 2)) columnds"))
737+
end
738+
reshape(convert(Matrix, D), size(D, 1))
739+
end
740+
convert{T}(::Type{Vector}, D::Dense{T}) = convert(Vector{T}, D)
733741

734742
function convert{Tv,Ti}(::Type{SparseMatrixCSC{Tv,Ti}}, A::Sparse{Tv,Ti})
735743
s = unsafe_load(A.p)

base/sparse/linalg.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -686,8 +686,10 @@ function factorize(A::SparseMatrixCSC)
686686
end
687687
end
688688
return lufact(A)
689+
elseif m > n
690+
return qrfact(A)
689691
else
690-
throw(ArgumentError("sparse least squares problems by QR are not handled yet"))
692+
throw(ArgumentError("underdetermined systemed are not implemented yet"))
691693
end
692694
end
693695

test/runtests.jl

+1-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ testnames = [
33
"linalg", "core", "keywordargs", "numbers", "strings", "dates",
44
"dict", "hashing", "remote", "iobuffer", "staged", "arrayops",
55
"subarray", "reduce", "reducedim", "random", "intfuncs",
6-
"simdloop", "blas", "fft", "dsp", "sparse", "bitarray", "copy", "math",
6+
"simdloop", "blas", "fft", "dsp", "sparsetests", "bitarray", "copy", "math",
77
"fastmath", "functional", "bigint", "sorting", "statistics", "spawn",
88
"backtrace", "priorityqueue", "arpack", "file", "version",
99
"resolve", "pollfd", "mpfr", "broadcast", "complex", "socket",
@@ -25,12 +25,6 @@ push!(testnames, "parallel")
2525

2626
tests = (ARGS==["all"] || isempty(ARGS)) ? testnames : ARGS
2727

28-
if "sparse" in tests
29-
# specifically selected case
30-
filter!(x -> x != "sparse", tests)
31-
prepend!(tests, ["sparse/sparse", "sparse/cholmod", "sparse/umfpack"])
32-
end
33-
3428
if "linalg" in tests
3529
# specifically selected case
3630
filter!(x -> x != "linalg", tests)

0 commit comments

Comments
 (0)