Skip to content

Commit 0ed6ae3

Browse files
authored
Merge pull request #20043 from Sacha0/fixglidepwarns
A[c|t]_mul_B[!] specializations for SparseMatrixCSC-StridedVecOrMat, less generalized linear indexing and meta-fu
2 parents fa31b38 + ada593d commit 0ed6ae3

File tree

1 file changed

+65
-42
lines changed

1 file changed

+65
-42
lines changed

base/sparse/linalg.jl

+65-42
Original file line numberDiff line numberDiff line change
@@ -43,54 +43,77 @@ end
4343

4444
# In matrix-vector multiplication, the correct orientation of the vector is assumed.
4545

46-
for (f, op, transp) in ((:A_mul_B, :identity, false),
47-
(:Ac_mul_B, :ctranspose, true),
48-
(:At_mul_B, :transpose, true))
49-
@eval begin
50-
function $(Symbol(f,:!))(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
51-
if $transp
52-
A.n == size(C, 1) || throw(DimensionMismatch())
53-
A.m == size(B, 1) || throw(DimensionMismatch())
54-
else
55-
A.n == size(B, 1) || throw(DimensionMismatch())
56-
A.m == size(C, 1) || throw(DimensionMismatch())
57-
end
58-
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
59-
nzv = A.nzval
60-
rv = A.rowval
61-
if β != 1
62-
β != 0 ? scale!(C, β) : fill!(C, zero(eltype(C)))
63-
end
64-
for col = 1:A.n
65-
for k = 1:size(C, 2)
66-
if $transp
67-
tmp = zero(eltype(C))
68-
@inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1)
69-
tmp += $(op)(nzv[j])*B[rv[j],k]
70-
end
71-
C[col,k] += α*tmp
72-
else
73-
αxj = α*B[col,k]
74-
@inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1)
75-
C[rv[j], k] += nzv[j]*αxj
76-
end
77-
end
78-
end
79-
end
80-
C
81-
end
46+
A_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) = A_mul_B!(_argstuple_AqmulB!(A, B)...)
47+
At_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) = At_mul_B!(_argstuple_AqmulB!(A, B)...)
48+
Ac_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) = Ac_mul_B!(_argstuple_AqmulB!(A, B)...)
49+
_argstuple_AqmulB!{TvA,TB}(A::SparseMatrixCSC{TvA}, b::StridedVector{TB}) =
50+
(R = promote_type(TvA, TB); (one(R), A, B, zero(R), similar(b, R, A.n)))
51+
_argstuple_AqmulB!{TvA,TB}(A::SparseMatrixCSC{TvA}, B::StridedMatrix{TB}) =
52+
(R = promote_type(TvA, TB); (one(R), A, B, zero(R), similar(B, R, (A.n, size(B,2)))))
53+
54+
A_mul_B!::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) =
55+
_Aq_mul_B!(α, A, identity, B, β, C)
56+
At_mul_B!::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) =
57+
_Aq_mul_B!(α, A, transpose, B, β, C)
58+
Ac_mul_B!::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) =
59+
_Aq_mul_B!(α, A, ctranspose, B, β, C)
60+
61+
function _Aq_mul_B!::Number, A::SparseMatrixCSC, transopA::Function,
62+
B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
63+
_AqmulB_checkshapecompat(A, transopA, B, C)
64+
_AqmulB_specialscale!(C, β)
65+
_AqmulB_kernel!(α, A, transopA, B, C)
66+
return C
67+
end
8268

83-
function $(f){TA,S,Tx}(A::SparseMatrixCSC{TA,S}, x::StridedVector{Tx})
84-
T = promote_type(TA, Tx)
85-
$(Symbol(f,:!))(one(T), A, x, zero(T), similar(x, T, A.n))
69+
qtransposefntype = Union{typeof(transpose), typeof(ctranspose)}
70+
_AqmulB_checkshapecompat(A, ::typeof(identity), B, C) = _AqmulB_checkshapecompat(A.m, A.n, B, C)
71+
_AqmulB_checkshapecompat(A, ::qtransposefntype, B, C) = _AqmulB_checkshapecompat(A.n, A.m, B, C)
72+
function _AqmulB_checkshapecompat(Aqm, Aqn, B, C)
73+
size(B, 1) == Aqn || throw(DimensionMismatch())
74+
size(C, 1) == Aqm || throw(DimensionMismatch())
75+
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
76+
end
77+
78+
_AqmulB_specialscale!(C::StridedVecOrMat, β::Number) =
79+
β == 1 ||== 0 ? fill!(C, zero(eltype(C))) : scale!(C, β))
80+
81+
function _AqmulB_kernel!::Number, A::SparseMatrixCSC, ::typeof(identity), B::Vector, C::Vector)
82+
for colA in 1:A.n
83+
αBforcolA = α * B[colA]
84+
@inbounds for indA in nzrange(A, colA)
85+
C[A.rowval[indA]] += A.nzval[indA] * αBforcolA
86+
end
87+
end
88+
end
89+
function _AqmulB_kernel!::Number, A::SparseMatrixCSC, ::typeof(identity), B::Matrix, C::Matrix)
90+
for colA in 1:A.n, colC in size(C, 2)
91+
αBforcolAcolC = α * B[colA, colC]
92+
@inbounds for indA in nzrange(A, colA)
93+
C[A.rowval[indA], colC] += A.nzval[indA] * αBforcolAcolC
8694
end
87-
function $(f){TA,S,Tx}(A::SparseMatrixCSC{TA,S}, B::StridedMatrix{Tx})
88-
T = promote_type(TA, Tx)
89-
$(Symbol(f,:!))(one(T), A, B, zero(T), similar(B, T, (A.n, size(B, 2))))
95+
end
96+
end
97+
function _AqmulB_kernel!::Number, A::SparseMatrixCSC, op::qtransposefntype, B::Vector, C::Vector)
98+
for colA in 1:A.n
99+
accumulator = zero(eltype(C))
100+
@inbounds for indA in nzrange(A, colA)
101+
accumulator += op(A.nzval[indA]) * B[A.rowval[indA]]
102+
end
103+
C[colA] += α * accumulator
104+
end
105+
end
106+
function _AqmulB_kernel!::Number, A::SparseMatrixCSC, op::qtransposefntype, B::Matrix, C::Matrix)
107+
for colA in 1:A.n, colC in 1:size(C, 2)
108+
accumulator = zero(eltype(C))
109+
@inbounds for indA in nzrange(A, colA)
110+
accumulator += op(A.nzval[indA]) * B[A.rowval[indA], colC]
90111
end
112+
C[colA, colC] += α * accumulator
91113
end
92114
end
93115

116+
94117
# For compatibility with dense multiplication API. Should be deleted when dense multiplication
95118
# API is updated to follow BLAS API.
96119
A_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) = A_mul_B!(one(eltype(B)), A, B, zero(eltype(C)), C)

0 commit comments

Comments
 (0)