Skip to content

Commit 5006312

Browse files
authored
Reduce matmul latency by splitting small matmul (#54421)
This splits `matmul2x2` and `matmul3x3` into components that depend on `MulAddMul` and those that don't depend on it. This improves compilation time, as the `MulAddMul`-independent methods won't need to be recompiled in the `@stable_muladdmul` branches.

TTFX (each call timed in a separate session):

```julia
julia> using LinearAlgebra

julia> A = rand(2,2); B = Symmetric(rand(2,2)); C = zeros(2,2);

julia> @time mul!(C, A, B);
  1.927468 seconds (5.67 M allocations: 282.523 MiB, 12.09% gc time, 100.00% compilation time) # nightly v"1.12.0-DEV.492"
  1.282717 seconds (4.46 M allocations: 228.816 MiB, 4.58% gc time, 100.00% compilation time) # This PR

julia> A = rand(2,2); B = rand(2,2); C = zeros(2,2);

julia> @time mul!(C, A, B);
  1.653368 seconds (5.75 M allocations: 291.586 MiB, 13.94% gc time, 100.00% compilation time) # nightly
  1.148330 seconds (4.46 M allocations: 230.714 MiB, 4.47% gc time, 100.00% compilation time) # This PR
```

Edit: Not inlining the function seems to incur a runtime performance cost.

```julia
julia> using LinearAlgebra

julia> A = rand(3,3); B = rand(size(A)...); C = zeros(size(A));

julia> @btime mul!($C, $A, $B);
  23.923 ns (0 allocations: 0 bytes) # nightly
  31.732 ns (0 allocations: 0 bytes) # This PR
```

Adding `@inline` annotations resolves this difference, but this reintroduces the compilation latency. The tradeoff is perhaps ok, as users may use `StaticArrays` for performance-critical matrix multiplications.
1 parent 25c8128 commit 5006312

File tree

1 file changed

+89
-110
lines changed

1 file changed

+89
-110
lines changed

stdlib/LinearAlgebra/src/matmul.jl

+89-110
Original file line numberDiff line numberDiff line change
@@ -930,164 +930,143 @@ end
930930

931931

932932
# multiply 2x2 matrices
933-
function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
933+
Base.@constprop :aggressive function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
934934
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
935935
end
936936

937-
function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
938-
_add::MulAddMul = MulAddMul())
937+
function __matmul_checks(C, A, B, sz)
939938
require_one_based_indexing(C, A, B)
940939
if C === A || B === C
941940
throw(ArgumentError("output matrix must not be aliased with input matrix"))
942941
end
943-
if !(size(A) == size(B) == size(C) == (2,2))
942+
if !(size(A) == size(B) == size(C) == sz)
944943
throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
945944
end
945+
return nothing
946+
end
947+
948+
# separate function with the core of matmul2x2! that doesn't depend on a MulAddMul
949+
Base.@constprop :aggressive function _matmul2x2_elements(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix)
950+
__matmul_checks(C, A, B, (2,2))
951+
__matmul2x2_elements(tA, tB, A, B)
952+
end
953+
Base.@constprop :aggressive function __matmul2x2_elements(tA, A::AbstractMatrix)
946954
@inbounds begin
947-
if tA == 'N'
955+
tA_uc = uppercase(tA) # possibly unwrap a WrapperChar
956+
if tA_uc == 'N'
948957
A11 = A[1,1]; A12 = A[1,2]; A21 = A[2,1]; A22 = A[2,2]
949-
elseif tA == 'T'
958+
elseif tA_uc == 'T'
950959
# TODO making these lazy could improve perf
951960
A11 = copy(transpose(A[1,1])); A12 = copy(transpose(A[2,1]))
952961
A21 = copy(transpose(A[1,2])); A22 = copy(transpose(A[2,2]))
953-
elseif tA == 'C'
962+
elseif tA_uc == 'C'
954963
# TODO making these lazy could improve perf
955964
A11 = copy(A[1,1]'); A12 = copy(A[2,1]')
956965
A21 = copy(A[1,2]'); A22 = copy(A[2,2]')
957-
elseif tA == 'S'
958-
A11 = symmetric(A[1,1], :U); A12 = A[1,2]
959-
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U)
960-
elseif tA == 's'
961-
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1]))
962-
A21 = A[2,1]; A22 = symmetric(A[2,2], :L)
963-
elseif tA == 'H'
964-
A11 = hermitian(A[1,1], :U); A12 = A[1,2]
965-
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U)
966-
else # if tA == 'h'
967-
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1]))
968-
A21 = A[2,1]; A22 = hermitian(A[2,2], :L)
969-
end
970-
if tB == 'N'
971-
B11 = B[1,1]; B12 = B[1,2];
972-
B21 = B[2,1]; B22 = B[2,2]
973-
elseif tB == 'T'
974-
# TODO making these lazy could improve perf
975-
B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1]))
976-
B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2]))
977-
elseif tB == 'C'
978-
# TODO making these lazy could improve perf
979-
B11 = copy(B[1,1]'); B12 = copy(B[2,1]')
980-
B21 = copy(B[1,2]'); B22 = copy(B[2,2]')
981-
elseif tB == 'S'
982-
B11 = symmetric(B[1,1], :U); B12 = B[1,2]
983-
B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U)
984-
elseif tB == 's'
985-
B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1]))
986-
B21 = B[2,1]; B22 = symmetric(B[2,2], :L)
987-
elseif tB == 'H'
988-
B11 = hermitian(B[1,1], :U); B12 = B[1,2]
989-
B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U)
990-
else # if tB == 'h'
991-
B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1]))
992-
B21 = B[2,1]; B22 = hermitian(B[2,2], :L)
966+
elseif tA_uc == 'S'
967+
if isuppercase(tA) # tA == 'S'
968+
A11 = symmetric(A[1,1], :U); A12 = A[1,2]
969+
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U)
970+
else
971+
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1]))
972+
A21 = A[2,1]; A22 = symmetric(A[2,2], :L)
973+
end
974+
elseif tA_uc == 'H'
975+
if isuppercase(tA) # tA == 'H'
976+
A11 = hermitian(A[1,1], :U); A12 = A[1,2]
977+
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U)
978+
else # if tA == 'h'
979+
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1]))
980+
A21 = A[2,1]; A22 = hermitian(A[2,2], :L)
981+
end
993982
end
983+
end # inbounds
984+
A11, A12, A21, A22
985+
end
986+
Base.@constprop :aggressive __matmul2x2_elements(tA, tB, A, B) = __matmul2x2_elements(tA, A), __matmul2x2_elements(tB, B)
987+
988+
Base.@constprop :aggressive function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
989+
_add::MulAddMul = MulAddMul())
990+
(A11, A12, A21, A22), (B11, B12, B21, B22) = _matmul2x2_elements(C, tA, tB, A, B)
991+
@inbounds begin
994992
_modify!(_add, A11*B11 + A12*B21, C, (1,1))
995-
_modify!(_add, A11*B12 + A12*B22, C, (1,2))
996993
_modify!(_add, A21*B11 + A22*B21, C, (2,1))
994+
_modify!(_add, A11*B12 + A12*B22, C, (1,2))
997995
_modify!(_add, A21*B12 + A22*B22, C, (2,2))
998996
end # inbounds
999997
C
1000998
end
1001999

10021000
# Multiply 3x3 matrices
1003-
function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
1001+
Base.@constprop :aggressive function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,S}
10041002
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
10051003
end
10061004

1007-
function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
1008-
_add::MulAddMul = MulAddMul())
1009-
require_one_based_indexing(C, A, B)
1010-
if C === A || B === C
1011-
throw(ArgumentError("output matrix must not be aliased with input matrix"))
1012-
end
1013-
if !(size(A) == size(B) == size(C) == (3,3))
1014-
throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
1015-
end
1005+
# separate function with the core of matmul3x3! that doesn't depend on a MulAddMul
1006+
Base.@constprop :aggressive function _matmul3x3_elements(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix)
1007+
__matmul_checks(C, A, B, (3,3))
1008+
__matmul3x3_elements(tA, tB, A, B)
1009+
end
1010+
Base.@constprop :aggressive function __matmul3x3_elements(tA, A::AbstractMatrix)
10161011
@inbounds begin
1017-
if tA == 'N'
1012+
tA_uc = uppercase(tA) # possibly unwrap a WrapperChar
1013+
if tA_uc == 'N'
10181014
A11 = A[1,1]; A12 = A[1,2]; A13 = A[1,3]
10191015
A21 = A[2,1]; A22 = A[2,2]; A23 = A[2,3]
10201016
A31 = A[3,1]; A32 = A[3,2]; A33 = A[3,3]
1021-
elseif tA == 'T'
1017+
elseif tA_uc == 'T'
10221018
# TODO making these lazy could improve perf
10231019
A11 = copy(transpose(A[1,1])); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1]))
10241020
A21 = copy(transpose(A[1,2])); A22 = copy(transpose(A[2,2])); A23 = copy(transpose(A[3,2]))
10251021
A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = copy(transpose(A[3,3]))
1026-
elseif tA == 'C'
1022+
elseif tA_uc == 'C'
10271023
# TODO making these lazy could improve perf
10281024
A11 = copy(A[1,1]'); A12 = copy(A[2,1]'); A13 = copy(A[3,1]')
10291025
A21 = copy(A[1,2]'); A22 = copy(A[2,2]'); A23 = copy(A[3,2]')
10301026
A31 = copy(A[1,3]'); A32 = copy(A[2,3]'); A33 = copy(A[3,3]')
1031-
elseif tA == 'S'
1032-
A11 = symmetric(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
1033-
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U); A23 = A[2,3]
1034-
A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = symmetric(A[3,3], :U)
1035-
elseif tA == 's'
1036-
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1]))
1037-
A21 = A[2,1]; A22 = symmetric(A[2,2], :L); A23 = copy(transpose(A[3,2]))
1038-
A31 = A[3,1]; A32 = A[3,2]; A33 = symmetric(A[3,3], :L)
1039-
elseif tA == 'H'
1040-
A11 = hermitian(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
1041-
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U); A23 = A[2,3]
1042-
A31 = copy(adjoint(A[1,3])); A32 = copy(adjoint(A[2,3])); A33 = hermitian(A[3,3], :U)
1043-
else # if tA == 'h'
1044-
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])); A13 = copy(adjoint(A[3,1]))
1045-
A21 = A[2,1]; A22 = hermitian(A[2,2], :L); A23 = copy(adjoint(A[3,2]))
1046-
A31 = A[3,1]; A32 = A[3,2]; A33 = hermitian(A[3,3], :L)
1027+
elseif tA_uc == 'S'
1028+
if isuppercase(tA) # tA == 'S'
1029+
A11 = symmetric(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
1030+
A21 = copy(transpose(A[1,2])); A22 = symmetric(A[2,2], :U); A23 = A[2,3]
1031+
A31 = copy(transpose(A[1,3])); A32 = copy(transpose(A[2,3])); A33 = symmetric(A[3,3], :U)
1032+
else
1033+
A11 = symmetric(A[1,1], :L); A12 = copy(transpose(A[2,1])); A13 = copy(transpose(A[3,1]))
1034+
A21 = A[2,1]; A22 = symmetric(A[2,2], :L); A23 = copy(transpose(A[3,2]))
1035+
A31 = A[3,1]; A32 = A[3,2]; A33 = symmetric(A[3,3], :L)
1036+
end
1037+
elseif tA_uc == 'H'
1038+
if isuppercase(tA) # tA == 'H'
1039+
A11 = hermitian(A[1,1], :U); A12 = A[1,2]; A13 = A[1,3]
1040+
A21 = copy(adjoint(A[1,2])); A22 = hermitian(A[2,2], :U); A23 = A[2,3]
1041+
A31 = copy(adjoint(A[1,3])); A32 = copy(adjoint(A[2,3])); A33 = hermitian(A[3,3], :U)
1042+
else # if tA == 'h'
1043+
A11 = hermitian(A[1,1], :L); A12 = copy(adjoint(A[2,1])); A13 = copy(adjoint(A[3,1]))
1044+
A21 = A[2,1]; A22 = hermitian(A[2,2], :L); A23 = copy(adjoint(A[3,2]))
1045+
A31 = A[3,1]; A32 = A[3,2]; A33 = hermitian(A[3,3], :L)
1046+
end
10471047
end
1048+
end # inbounds
1049+
A11, A12, A13, A21, A22, A23, A31, A32, A33
1050+
end
1051+
Base.@constprop :aggressive __matmul3x3_elements(tA, tB, A, B) = __matmul3x3_elements(tA, A), __matmul3x3_elements(tB, B)
10481052

1049-
if tB == 'N'
1050-
B11 = B[1,1]; B12 = B[1,2]; B13 = B[1,3]
1051-
B21 = B[2,1]; B22 = B[2,2]; B23 = B[2,3]
1052-
B31 = B[3,1]; B32 = B[3,2]; B33 = B[3,3]
1053-
elseif tB == 'T'
1054-
# TODO making these lazy could improve perf
1055-
B11 = copy(transpose(B[1,1])); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1]))
1056-
B21 = copy(transpose(B[1,2])); B22 = copy(transpose(B[2,2])); B23 = copy(transpose(B[3,2]))
1057-
B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = copy(transpose(B[3,3]))
1058-
elseif tB == 'C'
1059-
# TODO making these lazy could improve perf
1060-
B11 = copy(B[1,1]'); B12 = copy(B[2,1]'); B13 = copy(B[3,1]')
1061-
B21 = copy(B[1,2]'); B22 = copy(B[2,2]'); B23 = copy(B[3,2]')
1062-
B31 = copy(B[1,3]'); B32 = copy(B[2,3]'); B33 = copy(B[3,3]')
1063-
elseif tB == 'S'
1064-
B11 = symmetric(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3]
1065-
B21 = copy(transpose(B[1,2])); B22 = symmetric(B[2,2], :U); B23 = B[2,3]
1066-
B31 = copy(transpose(B[1,3])); B32 = copy(transpose(B[2,3])); B33 = symmetric(B[3,3], :U)
1067-
elseif tB == 's'
1068-
B11 = symmetric(B[1,1], :L); B12 = copy(transpose(B[2,1])); B13 = copy(transpose(B[3,1]))
1069-
B21 = B[2,1]; B22 = symmetric(B[2,2], :L); B23 = copy(transpose(B[3,2]))
1070-
B31 = B[3,1]; B32 = B[3,2]; B33 = symmetric(B[3,3], :L)
1071-
elseif tB == 'H'
1072-
B11 = hermitian(B[1,1], :U); B12 = B[1,2]; B13 = B[1,3]
1073-
B21 = copy(adjoint(B[1,2])); B22 = hermitian(B[2,2], :U); B23 = B[2,3]
1074-
B31 = copy(adjoint(B[1,3])); B32 = copy(adjoint(B[2,3])); B33 = hermitian(B[3,3], :U)
1075-
else # if tB == 'h'
1076-
B11 = hermitian(B[1,1], :L); B12 = copy(adjoint(B[2,1])); B13 = copy(adjoint(B[3,1]))
1077-
B21 = B[2,1]; B22 = hermitian(B[2,2], :L); B23 = copy(adjoint(B[3,2]))
1078-
B31 = B[3,1]; B32 = B[3,2]; B33 = hermitian(B[3,3], :L)
1079-
end
1053+
Base.@constprop :aggressive function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix,
1054+
_add::MulAddMul = MulAddMul())
10801055

1081-
_modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
1082-
_modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2))
1083-
_modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3))
1056+
(A11, A12, A13, A21, A22, A23, A31, A32, A33),
1057+
(B11, B12, B13, B21, B22, B23, B31, B32, B33) = _matmul3x3_elements(C, tA, tB, A, B)
10841058

1059+
@inbounds begin
1060+
_modify!(_add, A11*B11 + A12*B21 + A13*B31, C, (1,1))
10851061
_modify!(_add, A21*B11 + A22*B21 + A23*B31, C, (2,1))
1086-
_modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2))
1087-
_modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3))
1088-
10891062
_modify!(_add, A31*B11 + A32*B21 + A33*B31, C, (3,1))
1063+
1064+
_modify!(_add, A11*B12 + A12*B22 + A13*B32, C, (1,2))
1065+
_modify!(_add, A21*B12 + A22*B22 + A23*B32, C, (2,2))
10901066
_modify!(_add, A31*B12 + A32*B22 + A33*B32, C, (3,2))
1067+
1068+
_modify!(_add, A11*B13 + A12*B23 + A13*B33, C, (1,3))
1069+
_modify!(_add, A21*B13 + A22*B23 + A23*B33, C, (2,3))
10911070
_modify!(_add, A31*B13 + A32*B23 + A33*B33, C, (3,3))
10921071
end # inbounds
10931072
C

0 commit comments

Comments
 (0)