Skip to content

Commit 0af99e6

Browse files
authored
Call MulAddMul instead of multiplication in _generic_matmatmul! (#56089)
Fix https://github.com/JuliaLang/julia/issues/56085 by calling a newly created `MulAddMul` object that only wraps the `alpha` (with `beta` set to `false`). This avoids the explicit multiplication if `alpha` is known to be `isone`.
1 parent d749f0e commit 0af99e6

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

stdlib/LinearAlgebra/src/matmul.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -919,7 +919,7 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
919919
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
920920

921921
@noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S},
922-
_add::MulAddMul) where {T,S,R}
922+
_add::MulAddMul{ais1}) where {T,S,R,ais1}
923923
AxM = axes(A, 1)
924924
AxK = axes(A, 2) # we use two `axes` calls in case of `AbstractVector`
925925
BxK = axes(B, 1)
@@ -935,11 +935,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::A
935935
if BxN != CxN
936936
throw(DimensionMismatch(lazy"matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)"))
937937
end
938+
_rmul_alpha = MulAddMul{ais1,true,typeof(_add.alpha),Bool}(_add.alpha,false)
938939
if isbitstype(R) && sizeof(R) 16 && !(A isa Adjoint || A isa Transpose)
939940
_rmul_or_fill!(C, _add.beta)
940941
(iszero(_add.alpha) || isempty(A) || isempty(B)) && return C
941942
@inbounds for n in BxN, k in BxK
942-
Balpha = B[k,n]*_add.alpha
943+
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
944+
Balpha = _rmul_alpha(B[k,n])
943945
@simd for m in AxM
944946
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
945947
end

stdlib/LinearAlgebra/test/matmul.jl

+18
Original file line numberDiff line numberDiff line change
@@ -1130,4 +1130,22 @@ end
11301130
@test a * transpose(B) A * transpose(B)
11311131
end
11321132

1133+
@testset "issue #56085" begin
1134+
struct Thing
1135+
data::Float64
1136+
end
1137+
1138+
Base.zero(::Type{Thing}) = Thing(0.)
1139+
Base.zero(::Thing) = Thing(0.)
1140+
Base.one(::Type{Thing}) = Thing(1.)
1141+
Base.one(::Thing) = Thing(1.)
1142+
Base.:+(t::Thing...) = +(getfield.(t, :data)...)
1143+
Base.:*(t::Thing...) = *(getfield.(t, :data)...)
1144+
1145+
M = Float64[1 2; 3 4]
1146+
A = Thing.(M)
1147+
1148+
@test A * A M * M
1149+
end
1150+
11331151
end # module TestMatmul

0 commit comments

Comments
 (0)