@@ -930,164 +930,143 @@ end
930
930
931
931
932
932
# multiply 2x2 matrices
933
- function matmul2x2 (tA, tB, A:: AbstractMatrix{T} , B:: AbstractMatrix{S} ) where {T,S}
933
+ Base . @constprop :aggressive function matmul2x2 (tA, tB, A:: AbstractMatrix{T} , B:: AbstractMatrix{S} ) where {T,S}
934
934
matmul2x2! (similar (B, promote_op (matprod, T, S), 2 , 2 ), tA, tB, A, B)
935
935
end
936
936
937
- function matmul2x2! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
938
- _add:: MulAddMul = MulAddMul ())
937
+ function __matmul_checks (C, A, B, sz)
939
938
require_one_based_indexing (C, A, B)
940
939
if C === A || B === C
941
940
throw (ArgumentError (" output matrix must not be aliased with input matrix" ))
942
941
end
943
- if ! (size (A) == size (B) == size (C) == ( 2 , 2 ) )
942
+ if ! (size (A) == size (B) == size (C) == sz )
944
943
throw (DimensionMismatch (lazy " A has size $(size(A)), B has size $(size(B)), C has size $(size(C))" ))
945
944
end
945
+ return nothing
946
+ end
947
+
948
+ # separate function with the core of matmul2x2! that doesn't depend on a MulAddMul
949
+ Base. @constprop :aggressive function _matmul2x2_elements (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix )
950
+ __matmul_checks (C, A, B, (2 ,2 ))
951
+ __matmul2x2_elements (tA, tB, A, B)
952
+ end
953
+ Base. @constprop :aggressive function __matmul2x2_elements (tA, A:: AbstractMatrix )
946
954
@inbounds begin
947
- if tA == ' N'
955
+ tA_uc = uppercase (tA) # possibly unwrap a WrapperChar
956
+ if tA_uc == ' N'
948
957
A11 = A[1 ,1 ]; A12 = A[1 ,2 ]; A21 = A[2 ,1 ]; A22 = A[2 ,2 ]
949
- elseif tA == ' T'
958
+ elseif tA_uc == ' T'
950
959
# TODO making these lazy could improve perf
951
960
A11 = copy (transpose (A[1 ,1 ])); A12 = copy (transpose (A[2 ,1 ]))
952
961
A21 = copy (transpose (A[1 ,2 ])); A22 = copy (transpose (A[2 ,2 ]))
953
- elseif tA == ' C'
962
+ elseif tA_uc == ' C'
954
963
# TODO making these lazy could improve perf
955
964
A11 = copy (A[1 ,1 ]' ); A12 = copy (A[2 ,1 ]' )
956
965
A21 = copy (A[1 ,2 ]' ); A22 = copy (A[2 ,2 ]' )
957
- elseif tA == ' S'
958
- A11 = symmetric (A[1 ,1 ], :U ); A12 = A[1 ,2 ]
959
- A21 = copy (transpose (A[1 ,2 ])); A22 = symmetric (A[2 ,2 ], :U )
960
- elseif tA == ' s'
961
- A11 = symmetric (A[1 ,1 ], :L ); A12 = copy (transpose (A[2 ,1 ]))
962
- A21 = A[2 ,1 ]; A22 = symmetric (A[2 ,2 ], :L )
963
- elseif tA == ' H'
964
- A11 = hermitian (A[1 ,1 ], :U ); A12 = A[1 ,2 ]
965
- A21 = copy (adjoint (A[1 ,2 ])); A22 = hermitian (A[2 ,2 ], :U )
966
- else # if tA == 'h'
967
- A11 = hermitian (A[1 ,1 ], :L ); A12 = copy (adjoint (A[2 ,1 ]))
968
- A21 = A[2 ,1 ]; A22 = hermitian (A[2 ,2 ], :L )
969
- end
970
- if tB == ' N'
971
- B11 = B[1 ,1 ]; B12 = B[1 ,2 ];
972
- B21 = B[2 ,1 ]; B22 = B[2 ,2 ]
973
- elseif tB == ' T'
974
- # TODO making these lazy could improve perf
975
- B11 = copy (transpose (B[1 ,1 ])); B12 = copy (transpose (B[2 ,1 ]))
976
- B21 = copy (transpose (B[1 ,2 ])); B22 = copy (transpose (B[2 ,2 ]))
977
- elseif tB == ' C'
978
- # TODO making these lazy could improve perf
979
- B11 = copy (B[1 ,1 ]' ); B12 = copy (B[2 ,1 ]' )
980
- B21 = copy (B[1 ,2 ]' ); B22 = copy (B[2 ,2 ]' )
981
- elseif tB == ' S'
982
- B11 = symmetric (B[1 ,1 ], :U ); B12 = B[1 ,2 ]
983
- B21 = copy (transpose (B[1 ,2 ])); B22 = symmetric (B[2 ,2 ], :U )
984
- elseif tB == ' s'
985
- B11 = symmetric (B[1 ,1 ], :L ); B12 = copy (transpose (B[2 ,1 ]))
986
- B21 = B[2 ,1 ]; B22 = symmetric (B[2 ,2 ], :L )
987
- elseif tB == ' H'
988
- B11 = hermitian (B[1 ,1 ], :U ); B12 = B[1 ,2 ]
989
- B21 = copy (adjoint (B[1 ,2 ])); B22 = hermitian (B[2 ,2 ], :U )
990
- else # if tB == 'h'
991
- B11 = hermitian (B[1 ,1 ], :L ); B12 = copy (adjoint (B[2 ,1 ]))
992
- B21 = B[2 ,1 ]; B22 = hermitian (B[2 ,2 ], :L )
966
+ elseif tA_uc == ' S'
967
+ if isuppercase (tA) # tA == 'S'
968
+ A11 = symmetric (A[1 ,1 ], :U ); A12 = A[1 ,2 ]
969
+ A21 = copy (transpose (A[1 ,2 ])); A22 = symmetric (A[2 ,2 ], :U )
970
+ else
971
+ A11 = symmetric (A[1 ,1 ], :L ); A12 = copy (transpose (A[2 ,1 ]))
972
+ A21 = A[2 ,1 ]; A22 = symmetric (A[2 ,2 ], :L )
973
+ end
974
+ elseif tA_uc == ' H'
975
+ if isuppercase (tA) # tA == 'H'
976
+ A11 = hermitian (A[1 ,1 ], :U ); A12 = A[1 ,2 ]
977
+ A21 = copy (adjoint (A[1 ,2 ])); A22 = hermitian (A[2 ,2 ], :U )
978
+ else # if tA == 'h'
979
+ A11 = hermitian (A[1 ,1 ], :L ); A12 = copy (adjoint (A[2 ,1 ]))
980
+ A21 = A[2 ,1 ]; A22 = hermitian (A[2 ,2 ], :L )
981
+ end
993
982
end
983
+ end # inbounds
984
+ A11, A12, A21, A22
985
+ end
986
+ Base. @constprop :aggressive __matmul2x2_elements (tA, tB, A, B) = __matmul2x2_elements (tA, A), __matmul2x2_elements (tB, B)
987
+
988
+ Base. @constprop :aggressive function matmul2x2! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
989
+ _add:: MulAddMul = MulAddMul ())
990
+ (A11, A12, A21, A22), (B11, B12, B21, B22) = _matmul2x2_elements (C, tA, tB, A, B)
991
+ @inbounds begin
994
992
_modify! (_add, A11* B11 + A12* B21, C, (1 ,1 ))
995
- _modify! (_add, A11* B12 + A12* B22, C, (1 ,2 ))
996
993
_modify! (_add, A21* B11 + A22* B21, C, (2 ,1 ))
994
+ _modify! (_add, A11* B12 + A12* B22, C, (1 ,2 ))
997
995
_modify! (_add, A21* B12 + A22* B22, C, (2 ,2 ))
998
996
end # inbounds
999
997
C
1000
998
end
1001
999
1002
1000
# Multiply 3x3 matrices
1003
- function matmul3x3 (tA, tB, A:: AbstractMatrix{T} , B:: AbstractMatrix{S} ) where {T,S}
1001
+ Base . @constprop :aggressive function matmul3x3 (tA, tB, A:: AbstractMatrix{T} , B:: AbstractMatrix{S} ) where {T,S}
1004
1002
matmul3x3! (similar (B, promote_op (matprod, T, S), 3 , 3 ), tA, tB, A, B)
1005
1003
end
1006
1004
1007
- function matmul3x3! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
1008
- _add:: MulAddMul = MulAddMul ())
1009
- require_one_based_indexing (C, A, B)
1010
- if C === A || B === C
1011
- throw (ArgumentError (" output matrix must not be aliased with input matrix" ))
1012
- end
1013
- if ! (size (A) == size (B) == size (C) == (3 ,3 ))
1014
- throw (DimensionMismatch (lazy " A has size $(size(A)), B has size $(size(B)), C has size $(size(C))" ))
1015
- end
1005
+ # separate function with the core of matmul3x3! that doesn't depend on a MulAddMul
1006
+ Base. @constprop :aggressive function _matmul3x3_elements (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix )
1007
+ __matmul_checks (C, A, B, (3 ,3 ))
1008
+ __matmul3x3_elements (tA, tB, A, B)
1009
+ end
1010
+ Base. @constprop :aggressive function __matmul3x3_elements (tA, A:: AbstractMatrix )
1016
1011
@inbounds begin
1017
- if tA == ' N'
1012
+ tA_uc = uppercase (tA) # possibly unwrap a WrapperChar
1013
+ if tA_uc == ' N'
1018
1014
A11 = A[1 ,1 ]; A12 = A[1 ,2 ]; A13 = A[1 ,3 ]
1019
1015
A21 = A[2 ,1 ]; A22 = A[2 ,2 ]; A23 = A[2 ,3 ]
1020
1016
A31 = A[3 ,1 ]; A32 = A[3 ,2 ]; A33 = A[3 ,3 ]
1021
- elseif tA == ' T'
1017
+ elseif tA_uc == ' T'
1022
1018
# TODO making these lazy could improve perf
1023
1019
A11 = copy (transpose (A[1 ,1 ])); A12 = copy (transpose (A[2 ,1 ])); A13 = copy (transpose (A[3 ,1 ]))
1024
1020
A21 = copy (transpose (A[1 ,2 ])); A22 = copy (transpose (A[2 ,2 ])); A23 = copy (transpose (A[3 ,2 ]))
1025
1021
A31 = copy (transpose (A[1 ,3 ])); A32 = copy (transpose (A[2 ,3 ])); A33 = copy (transpose (A[3 ,3 ]))
1026
- elseif tA == ' C'
1022
+ elseif tA_uc == ' C'
1027
1023
# TODO making these lazy could improve perf
1028
1024
A11 = copy (A[1 ,1 ]' ); A12 = copy (A[2 ,1 ]' ); A13 = copy (A[3 ,1 ]' )
1029
1025
A21 = copy (A[1 ,2 ]' ); A22 = copy (A[2 ,2 ]' ); A23 = copy (A[3 ,2 ]' )
1030
1026
A31 = copy (A[1 ,3 ]' ); A32 = copy (A[2 ,3 ]' ); A33 = copy (A[3 ,3 ]' )
1031
- elseif tA == ' S'
1032
- A11 = symmetric (A[1 ,1 ], :U ); A12 = A[1 ,2 ]; A13 = A[1 ,3 ]
1033
- A21 = copy (transpose (A[1 ,2 ])); A22 = symmetric (A[2 ,2 ], :U ); A23 = A[2 ,3 ]
1034
- A31 = copy (transpose (A[1 ,3 ])); A32 = copy (transpose (A[2 ,3 ])); A33 = symmetric (A[3 ,3 ], :U )
1035
- elseif tA == ' s'
1036
- A11 = symmetric (A[1 ,1 ], :L ); A12 = copy (transpose (A[2 ,1 ])); A13 = copy (transpose (A[3 ,1 ]))
1037
- A21 = A[2 ,1 ]; A22 = symmetric (A[2 ,2 ], :L ); A23 = copy (transpose (A[3 ,2 ]))
1038
- A31 = A[3 ,1 ]; A32 = A[3 ,2 ]; A33 = symmetric (A[3 ,3 ], :L )
1039
- elseif tA == ' H'
1040
- A11 = hermitian (A[1 ,1 ], :U ); A12 = A[1 ,2 ]; A13 = A[1 ,3 ]
1041
- A21 = copy (adjoint (A[1 ,2 ])); A22 = hermitian (A[2 ,2 ], :U ); A23 = A[2 ,3 ]
1042
- A31 = copy (adjoint (A[1 ,3 ])); A32 = copy (adjoint (A[2 ,3 ])); A33 = hermitian (A[3 ,3 ], :U )
1043
- else # if tA == 'h'
1044
- A11 = hermitian (A[1 ,1 ], :L ); A12 = copy (adjoint (A[2 ,1 ])); A13 = copy (adjoint (A[3 ,1 ]))
1045
- A21 = A[2 ,1 ]; A22 = hermitian (A[2 ,2 ], :L ); A23 = copy (adjoint (A[3 ,2 ]))
1046
- A31 = A[3 ,1 ]; A32 = A[3 ,2 ]; A33 = hermitian (A[3 ,3 ], :L )
1027
+ elseif tA_uc == ' S'
1028
+ if isuppercase (tA) # tA == 'S'
1029
+ A11 = symmetric (A[1 ,1 ], :U ); A12 = A[1 ,2 ]; A13 = A[1 ,3 ]
1030
+ A21 = copy (transpose (A[1 ,2 ])); A22 = symmetric (A[2 ,2 ], :U ); A23 = A[2 ,3 ]
1031
+ A31 = copy (transpose (A[1 ,3 ])); A32 = copy (transpose (A[2 ,3 ])); A33 = symmetric (A[3 ,3 ], :U )
1032
+ else
1033
+ A11 = symmetric (A[1 ,1 ], :L ); A12 = copy (transpose (A[2 ,1 ])); A13 = copy (transpose (A[3 ,1 ]))
1034
+ A21 = A[2 ,1 ]; A22 = symmetric (A[2 ,2 ], :L ); A23 = copy (transpose (A[3 ,2 ]))
1035
+ A31 = A[3 ,1 ]; A32 = A[3 ,2 ]; A33 = symmetric (A[3 ,3 ], :L )
1036
+ end
1037
+ elseif tA_uc == ' H'
1038
+ if isuppercase (tA) # tA == 'H'
1039
+ A11 = hermitian (A[1 ,1 ], :U ); A12 = A[1 ,2 ]; A13 = A[1 ,3 ]
1040
+ A21 = copy (adjoint (A[1 ,2 ])); A22 = hermitian (A[2 ,2 ], :U ); A23 = A[2 ,3 ]
1041
+ A31 = copy (adjoint (A[1 ,3 ])); A32 = copy (adjoint (A[2 ,3 ])); A33 = hermitian (A[3 ,3 ], :U )
1042
+ else # if tA == 'h'
1043
+ A11 = hermitian (A[1 ,1 ], :L ); A12 = copy (adjoint (A[2 ,1 ])); A13 = copy (adjoint (A[3 ,1 ]))
1044
+ A21 = A[2 ,1 ]; A22 = hermitian (A[2 ,2 ], :L ); A23 = copy (adjoint (A[3 ,2 ]))
1045
+ A31 = A[3 ,1 ]; A32 = A[3 ,2 ]; A33 = hermitian (A[3 ,3 ], :L )
1046
+ end
1047
1047
end
1048
+ end # inbounds
1049
+ A11, A12, A13, A21, A22, A23, A31, A32, A33
1050
+ end
1051
+ Base. @constprop :aggressive __matmul3x3_elements (tA, tB, A, B) = __matmul3x3_elements (tA, A), __matmul3x3_elements (tB, B)
1048
1052
1049
- if tB == ' N'
1050
- B11 = B[1 ,1 ]; B12 = B[1 ,2 ]; B13 = B[1 ,3 ]
1051
- B21 = B[2 ,1 ]; B22 = B[2 ,2 ]; B23 = B[2 ,3 ]
1052
- B31 = B[3 ,1 ]; B32 = B[3 ,2 ]; B33 = B[3 ,3 ]
1053
- elseif tB == ' T'
1054
- # TODO making these lazy could improve perf
1055
- B11 = copy (transpose (B[1 ,1 ])); B12 = copy (transpose (B[2 ,1 ])); B13 = copy (transpose (B[3 ,1 ]))
1056
- B21 = copy (transpose (B[1 ,2 ])); B22 = copy (transpose (B[2 ,2 ])); B23 = copy (transpose (B[3 ,2 ]))
1057
- B31 = copy (transpose (B[1 ,3 ])); B32 = copy (transpose (B[2 ,3 ])); B33 = copy (transpose (B[3 ,3 ]))
1058
- elseif tB == ' C'
1059
- # TODO making these lazy could improve perf
1060
- B11 = copy (B[1 ,1 ]' ); B12 = copy (B[2 ,1 ]' ); B13 = copy (B[3 ,1 ]' )
1061
- B21 = copy (B[1 ,2 ]' ); B22 = copy (B[2 ,2 ]' ); B23 = copy (B[3 ,2 ]' )
1062
- B31 = copy (B[1 ,3 ]' ); B32 = copy (B[2 ,3 ]' ); B33 = copy (B[3 ,3 ]' )
1063
- elseif tB == ' S'
1064
- B11 = symmetric (B[1 ,1 ], :U ); B12 = B[1 ,2 ]; B13 = B[1 ,3 ]
1065
- B21 = copy (transpose (B[1 ,2 ])); B22 = symmetric (B[2 ,2 ], :U ); B23 = B[2 ,3 ]
1066
- B31 = copy (transpose (B[1 ,3 ])); B32 = copy (transpose (B[2 ,3 ])); B33 = symmetric (B[3 ,3 ], :U )
1067
- elseif tB == ' s'
1068
- B11 = symmetric (B[1 ,1 ], :L ); B12 = copy (transpose (B[2 ,1 ])); B13 = copy (transpose (B[3 ,1 ]))
1069
- B21 = B[2 ,1 ]; B22 = symmetric (B[2 ,2 ], :L ); B23 = copy (transpose (B[3 ,2 ]))
1070
- B31 = B[3 ,1 ]; B32 = B[3 ,2 ]; B33 = symmetric (B[3 ,3 ], :L )
1071
- elseif tB == ' H'
1072
- B11 = hermitian (B[1 ,1 ], :U ); B12 = B[1 ,2 ]; B13 = B[1 ,3 ]
1073
- B21 = copy (adjoint (B[1 ,2 ])); B22 = hermitian (B[2 ,2 ], :U ); B23 = B[2 ,3 ]
1074
- B31 = copy (adjoint (B[1 ,3 ])); B32 = copy (adjoint (B[2 ,3 ])); B33 = hermitian (B[3 ,3 ], :U )
1075
- else # if tB == 'h'
1076
- B11 = hermitian (B[1 ,1 ], :L ); B12 = copy (adjoint (B[2 ,1 ])); B13 = copy (adjoint (B[3 ,1 ]))
1077
- B21 = B[2 ,1 ]; B22 = hermitian (B[2 ,2 ], :L ); B23 = copy (adjoint (B[3 ,2 ]))
1078
- B31 = B[3 ,1 ]; B32 = B[3 ,2 ]; B33 = hermitian (B[3 ,3 ], :L )
1079
- end
1053
+ Base. @constprop :aggressive function matmul3x3! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
1054
+ _add:: MulAddMul = MulAddMul ())
1080
1055
1081
- _modify! (_add, A11* B11 + A12* B21 + A13* B31, C, (1 ,1 ))
1082
- _modify! (_add, A11* B12 + A12* B22 + A13* B32, C, (1 ,2 ))
1083
- _modify! (_add, A11* B13 + A12* B23 + A13* B33, C, (1 ,3 ))
1056
+ (A11, A12, A13, A21, A22, A23, A31, A32, A33),
1057
+ (B11, B12, B13, B21, B22, B23, B31, B32, B33) = _matmul3x3_elements (C, tA, tB, A, B)
1084
1058
1059
+ @inbounds begin
1060
+ _modify! (_add, A11* B11 + A12* B21 + A13* B31, C, (1 ,1 ))
1085
1061
_modify! (_add, A21* B11 + A22* B21 + A23* B31, C, (2 ,1 ))
1086
- _modify! (_add, A21* B12 + A22* B22 + A23* B32, C, (2 ,2 ))
1087
- _modify! (_add, A21* B13 + A22* B23 + A23* B33, C, (2 ,3 ))
1088
-
1089
1062
_modify! (_add, A31* B11 + A32* B21 + A33* B31, C, (3 ,1 ))
1063
+
1064
+ _modify! (_add, A11* B12 + A12* B22 + A13* B32, C, (1 ,2 ))
1065
+ _modify! (_add, A21* B12 + A22* B22 + A23* B32, C, (2 ,2 ))
1090
1066
_modify! (_add, A31* B12 + A32* B22 + A33* B32, C, (3 ,2 ))
1067
+
1068
+ _modify! (_add, A11* B13 + A12* B23 + A13* B33, C, (1 ,3 ))
1069
+ _modify! (_add, A21* B13 + A22* B23 + A23* B33, C, (2 ,3 ))
1091
1070
_modify! (_add, A31* B13 + A32* B23 + A33* B33, C, (3 ,3 ))
1092
1071
end # inbounds
1093
1072
C
0 commit comments