Skip to content

Commit 81f5645

Browse files
authored
Merge pull request #20046 from Sacha0/degliumf
A[c|t]_ldiv_B! specializations for UmfpackLU-StridedVecOrMat, less generalized linear indexing and meta-fu
2 parents 0ed6ae3 + c536776 commit 81f5645

File tree

1 file changed

+47
-34
lines changed

1 file changed

+47
-34
lines changed

base/sparse/umfpack.jl

+47-34
Original file line numberDiff line numberDiff line change
@@ -383,43 +383,56 @@ function nnz(lu::UmfpackLU)
383383
end
384384

385385
### Solve with Factorization
386-
for (f!, umfpack) in ((:A_ldiv_B!, :UMFPACK_A),
387-
(:Ac_ldiv_B!, :UMFPACK_At),
388-
(:At_ldiv_B!, :UMFPACK_Aat))
389-
@eval begin
390-
function $f!{T<:UMFVTypes}(x::StridedVecOrMat{T}, lu::UmfpackLU{T}, b::StridedVecOrMat{T})
391-
n = size(x, 2)
392-
if n != size(b, 2)
393-
throw(DimensionMismatch("in and output arrays must have the same number of columns"))
394-
end
395-
for j in 1:n
396-
solve!(view(x, :, j), lu, view(b, :, j), $umfpack)
397-
end
398-
return x
399-
end
400-
$f!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVector{T}) = $f!(b, lu, copy(b))
401-
$f!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedMatrix{T}) = $f!(b, lu, copy(b))
402-
403-
function $f!{Tb<:Complex}(x::StridedVector{Tb}, lu::UmfpackLU{Float64}, b::StridedVector{Tb})
404-
m, n = size(x, 1), size(x, 2)
405-
if n != size(b, 2)
406-
throw(DimensionMismatch("in and output arrays must have the same number of columns"))
407-
end
408-
# TODO: Optionally let user allocate these and pass in somehow
409-
r = similar(b, Float64, m)
410-
i = similar(b, Float64, m)
411-
for j in 1:n
412-
solve!(r, lu, convert(Vector{Float64}, real(view(b, :, j))), $umfpack)
413-
solve!(i, lu, convert(Vector{Float64}, imag(view(b, :, j))), $umfpack)
414-
415-
map!((t,s) -> t + im*s, view(x, :, j), r, i)
416-
end
417-
return x
418-
end
419-
$f!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVector{Tb}) = $f!(b, lu, copy(b))
386+
A_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVecOrMat{T}) = A_ldiv_B!(b, lu, copy(b))
387+
At_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVecOrMat{T}) = At_ldiv_B!(b, lu, copy(b))
388+
Ac_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVecOrMat{T}) = Ac_ldiv_B!(b, lu, copy(b))
389+
390+
A_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
391+
_Aq_ldiv_B!(X, lu, B, UMFPACK_A)
392+
At_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
393+
_Aq_ldiv_B!(X, lu, B, UMFPACK_At)
394+
Ac_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
395+
_Aq_ldiv_B!(X, lu, B, UMFPACK_Aat)
396+
397+
A_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVecOrMat{Tb}) = A_ldiv_B!(b, lu, copy(b))
398+
At_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVecOrMat{Tb}) = At_ldiv_B!(b, lu, copy(b))
399+
Ac_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVecOrMat{Tb}) = Ac_ldiv_B!(b, lu, copy(b))
400+
401+
A_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
402+
_Aq_ldiv_B!(X, lu, B, UMFPACK_A)
403+
At_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
404+
_Aq_ldiv_B!(X, lu, B, UMFPACK_At)
405+
Ac_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
406+
_Aq_ldiv_B!(X, lu, B, UMFPACK_Aat)
407+
408+
_Aq_ldiv_B!(X::StridedVecOrMat, lu::UmfpackLU, B::StridedVecOrMat, transtype) =
409+
(_AqldivB_checkshapecompat(X, B); _AqldivB_kernel!(X, lu, B, transtype); return X)
410+
411+
_AqldivB_checkshapecompat(X::StridedVecOrMat, B::StridedVecOrMat) =
412+
size(X, 2) == size(B, 2) || throw(DimensionMismatch("input and output must have same column count"))
413+
414+
_AqldivB_kernel!{T<:UMFVTypes}(x::StridedVector{T}, lu::UmfpackLU{T}, b::StridedVector{T}, transtype) =
415+
solve!(x, lu, b, transtype)
416+
_AqldivB_kernel!{T<:UMFVTypes}(X::StridedMatrix{T}, lu::UmfpackLU{T}, B::StridedMatrix{T}, transtype) =
417+
for col in 1:size(X, 1) solve!(view(X, :, col), lu, view(B, :, col), transtype) end
418+
419+
function _AqldivB_kernel!{Tb<:Complex}(x::StridedVector{Tb}, lu::UmfpackLU{Float64}, b::StridedVector{Tb}, transtype)
420+
r, i = similar(b, Float64), similar(b, Float64)
421+
solve!(r, lu, Vector{Float64}(real(b)), transtype)
422+
solve!(i, lu, Vector{Float64}(imag(b)), transtype)
423+
map!(complex, x, r, i)
424+
end
425+
function _AqldivB_kernel!{Tb<:Complex}(X::StridedMatrix{Tb}, lu::UmfpackLU{Float64}, B::StridedMatrix{Tb}, transtype)
426+
r = similar(B, Float64, size(B, 1))
427+
i = similar(B, Float64, size(B, 1))
428+
for j in 1:size(B, 2)
429+
solve!(r, lu, Vector{Float64}(real(view(B, :, j))), transtype)
430+
solve!(i, lu, Vector{Float64}(imag(view(B, :, j))), transtype)
431+
map!(complex, view(X, :, j), r, i)
420432
end
421433
end
422434

435+
423436
function getindex(lu::UmfpackLU, d::Symbol)
424437
L,U,p,q,Rs = umf_extract(lu)
425438
if d == :L

0 commit comments

Comments
 (0)