@@ -919,7 +919,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
919
919
_generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul (α, β))
920
920
921
921
@noinline function _generic_matmatmul! (C:: AbstractVecOrMat{R} , A:: AbstractVecOrMat{T} , B:: AbstractVecOrMat{S} ,
922
- _add:: MulAddMul ) where {T,S,R}
922
+ _add:: MulAddMul{ais1} ) where {T,S,R,ais1 }
923
923
AxM = axes (A, 1 )
924
924
AxK = axes (A, 2 ) # we use two `axes` calls in case of `AbstractVector`
925
925
BxK = axes (B, 1 )
@@ -935,11 +935,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
935
935
if BxN != CxN
936
936
throw (DimensionMismatch (lazy " matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)" ))
937
937
end
938
+ _rmul_alpha = MulAddMul {ais1,true,typeof(_add.alpha),Bool} (_add. alpha,false )
938
939
if isbitstype (R) && sizeof (R) ≤ 16 && ! (A isa Adjoint || A isa Transpose)
939
940
_rmul_or_fill! (C, _add. beta)
940
941
(iszero (_add. alpha) || isempty (A) || isempty (B)) && return C
941
942
@inbounds for n in BxN, k in BxK
942
- Balpha = B[k,n]* _add. alpha
943
+ # Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
944
+ Balpha = _rmul_alpha (B[k,n])
943
945
@simd for m in AxM
944
946
C[m,n] = muladd (A[m,k], Balpha, C[m,n])
945
947
end
0 commit comments