Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add specialized methods for complex-real BLAS multiplication. #6235

Merged
merged 1 commit into from
Apr 10, 2014
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 135 additions & 106 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# matmul.jl: Everything to do with dense matrix multiplication

arithtype(T) = T
arithtype(::Type{Bool}) = Int

# multiply by diagonal matrix as vector
function scale!(C::Matrix, A::Matrix, b::Vector)
m, n = size(A)
Expand Down Expand Up @@ -56,88 +59,126 @@ At_mul_B{T<:Real}(x::Vector{T}, y::Vector{T}) = [dot(x, y)]
At_mul_B{T<:BlasComplex}(x::Vector{T}, y::Vector{T}) = [BLAS.dotu(x, y)]

# Matrix-vector multiplication

*{T<:Union(Float32,Integer,Rational)}(A::StridedMatrix{Float64}, X::StridedVector{T}) = A*convert(Vector{eltype(A)},X)
*{T<:Union(Float32,Complex64,Integer,Rational)}(A::StridedMatrix{Complex128}, X::StridedVector{T}) = A*convert(Vector{eltype(A)},X)
function *{T<:BlasFloat}(A::StridedMatrix{T}, X::StridedVector{T})
gemv(similar(A, size(A,1)), 'N', A, X)
function (*){T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_type(arithtype(T),arithtype(S))
A_mul_B!(similar(x, TS, size(A,1)), A, convert(AbstractVector{TS}, x))
end

A_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = gemv(y, 'N', A, x)
A_mul_B!(y::StridedVector, A::StridedMatrix, x::StridedVector) = generic_matvecmul(y, 'N', A, x)

At_mul_B{T<:Union(Float32,Integer,Rational)}(A::StridedMatrix{Float64}, X::StridedVector{T}) = At_mul_B(A,convert(Vector{eltype(A)},X))
At_mul_B{T<:Union(Float32,Complex64,Integer,Rational)}(A::StridedMatrix{Complex128}, X::StridedVector{T}) = At_mul_B(A,convert(Vector{eltype(A)},X))
function At_mul_B{T<:BlasFloat}(A::StridedMatrix{T}, x::StridedVector{T})
gemv(similar(A, size(A, 2)), 'T', A, x)
function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
TS = promote_type(arithtype(T),arithtype(S))
A_mul_B!(similar(x,TS,size(A,1)),A,x)
end
(*)(A::AbstractVector, B::AbstractMatrix) = reshape(A,length(A),1)*B

A_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = gemv!(y, 'N', A, x)
for elty in (Float32,Float64)
@eval begin
function A_mul_B!(y::StridedVector{Complex{$elty}}, A::StridedMatrix{Complex{$elty}}, x::StridedVector{$elty})
Afl = reinterpret($elty,A,(2size(A,1),size(A,2)))
yfl = reinterpret($elty,y)
gemv!(yfl,'N',Afl,x)
return y
end
end
end
A_mul_B!(y::StridedVector, A::StridedMatrix, x::StridedVector) = generic_matvecmul!(y, 'N', A, x)

At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = gemv(y, 'T', A, x)
At_mul_B!(y::StridedVector, A::StridedMatrix, x::StridedVector) = generic_matvecmul(y, 'T', A, x)
function At_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_type(arithtype(T),arithtype(S))
At_mul_B!(similar(x,TS,size(A,2)), A, convert(AbstractVector{TS}, x))
end
function At_mul_B{T,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_type(arithtype(T),arithtype(S))
At_mul_B!(similar(x,TS,size(A,2)), A, x)
end
At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x)
At_mul_B!(y::StridedVector, A::StridedMatrix, x::StridedVector) = generic_matvecmul!(y, 'T', A, x)

Ac_mul_B{T<:Union(Float32,Integer,Rational)}(A::StridedMatrix{Float64}, X::StridedVector{T}) = Ac_mul_B(A,convert(Vector{eltype(A)},X))
Ac_mul_B{T<:Union(Float32,Complex64,Integer,Rational)}(A::StridedMatrix{Complex128}, X::StridedVector{T}) = Ac_mul_B(A,convert(Vector{eltype(A)},X))
Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, x::StridedVector{T}) = At_mul_B(A, x)
function Ac_mul_B{T<:BlasComplex}(A::StridedMatrix{T}, x::StridedVector{T})
gemv(similar(A, size(A, 2)), 'C', A, x)
function Ac_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_type(arithtype(T),arithtype(S))
Ac_mul_B!(similar(x,TS,size(A,2)),A,convert(AbstractVector{TS},x))
end
function Ac_mul_B{T,S}(A::StridedMatrix{T}, x::StridedVector{S})
TS = promote_type(arithtype(T),arithtype(S))
Ac_mul_B!(similar(x,TS,size(A,2)), A, x)
end

Ac_mul_B!{T<:BlasReal}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = At_mul_B!(y, A, x)
Ac_mul_B!{T<:BlasComplex}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = gemv(y, 'C', A, x)
Ac_mul_B!(y::StridedVector, A::StridedMatrix, x::StridedVector) = generic_matvecmul(y, 'C', A, x)

Ac_mul_B!{T<:BlasComplex}(y::StridedVector{T}, A::StridedMatrix{T}, x::StridedVector{T}) = gemv!(y, 'C', A, x)
Ac_mul_B!(y::StridedVector, A::StridedMatrix, x::StridedVector) = generic_matvecmul!(y, 'C', A, x)

# Matrix-matrix multiplication

(*){T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper('N', 'N', A, B)
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper(C, 'N', 'N', A, B)
A_mul_B!{T,S,R}(C::StridedMatrix{R}, A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul(C, 'N', 'N', A, B)

function At_mul_B{T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T})
is(A, B) ? syrk_wrapper('T', A) : gemm_wrapper('T', 'N', A, B)
function (*){T,S}(A::StridedMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T),arithtype(S))
A_mul_B!(similar(B,TS,(size(A,1),size(B,2))),A,B)
end
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
for elty in (Float32,Float64)
@eval begin
function A_mul_B!(C::StridedMatrix{Complex{$elty}}, A::StridedMatrix{Complex{$elty}}, B::StridedMatrix{$elty})
Afl = reinterpret($elty,A,(2size(A,1),size(A,2)))
Cfl = reinterpret($elty,C,(2size(C,1),size(C,2)))
gemm_wrapper!(Cfl,'N','N',Afl,B)
return C
end
end
end
A_mul_B!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'N', 'N', A, B)

At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedMatrix{T}) = gemm_wrapper(C, 'T', 'N', A, B)
At_mul_B{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul('T', 'N', A, B)
At_mul_B!{T,S,R}(C::StridedMatrix{R}, A::StridedVecOrMat{T}, B::StridedMatrix{S}) = generic_matmatmul(C, 'T', 'N', A, B)

function A_mul_Bt{T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T})
is(A, B) ? syrk_wrapper('N', A) : gemm_wrapper('N', 'T', A, B)
function At_mul_B{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T),arithtype(S))
At_mul_B!(similar(B,TS,(size(A,2),size(B,2))),A,B)
end
At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = is(A,B) ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)
At_mul_B!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'T', 'N', A, B)

A_mul_Bt!{T<:BlasFloat}(C::StridedVecOrMat{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper(C, 'N', 'T', A, B)
A_mul_Bt{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul('N', 'T', A, B)
A_mul_Bt!{T,S,R}(C::StridedVecOrMat{R}, A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul(C, 'N', 'T', A, B)
function A_mul_Bt{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T),arithtype(S))
A_mul_Bt!(similar(B,TS,(size(A,1),size(B,1))),A,B)
end
A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = is(A,B) ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)
for elty in (Float32,Float64)
@eval begin
function A_mul_Bt!(C::StridedMatrix{Complex{$elty}}, A::StridedMatrix{Complex{$elty}}, B::StridedMatrix{$elty})
Afl = reinterpret($elty,A,(2size(A,1),size(A,2)))
Cfl = reinterpret($elty,C,(2size(C,1),size(C,2)))
gemm_wrapper!(Cfl,'N','T',Afl,B)
return C
end
end
end
A_mul_Bt!(C::StridedVecOrMat, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'N', 'T', A, B)

At_mul_Bt{T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper('T', 'T', A, B)
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper(C, 'T', 'T', A, B)
At_mul_Bt{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul('T', 'T', A, B)
At_mul_Bt!{T,S,R}(C::StridedMatrix{R}, A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul(C, 'T', 'T', A, B)
function At_mul_Bt{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T),arithtype(S))
At_mul_Bt!(similar(B,TS,(size(A,2),size(B,1))),A,B)
end
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper!(C, 'T', 'T', A, B)
At_mul_Bt!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'T', 'T', A, B)

Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B(A, B)
Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B!(C, A, B)

function Ac_mul_B{T<:BlasComplex}(A::StridedMatrix{T}, B::StridedMatrix{T})
is(A, B) ? herk_wrapper('C', A) : gemm_wrapper('C', 'N', A, B)
function Ac_mul_B{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T),arithtype(S))
Ac_mul_B!(similar(B,TS,(size(A,2),size(B,2))),A,B)
end

Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper('C', 'N', A, B)
Ac_mul_B{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul('C', 'N', A, B)
Ac_mul_B!{T,S,R}(C::StridedMatrix{R}, A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul(C, 'C', 'N', A, B)

A_mul_Bc{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = A_mul_Bt(A, B)
A_mul_Bc!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = A_mul_Bt!(C, A, B)
function A_mul_Bc{T<:BlasComplex}(A::StridedMatrix{T}, B::StridedMatrix{T})
is(A, B) ? herk_wrapper('N', A) : gemm_wrapper('N', 'C', A, B)
Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = is(A,B) ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)
Ac_mul_B!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'C', 'N', A, B)

A_mul_Bc{T<:BlasFloat,S<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt(A, B)
A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt!(C, A, B)
function A_mul_Bc{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S})
TS = promote_type(arithtype(T),arithtype(S))
A_mul_Bc!(similar(B,TS,(size(A,1),size(B,1))),A,B)
end
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper(C, 'N', 'C', A, B)
A_mul_Bc{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul('N', 'C', A, B)
A_mul_Bc!{T,S,R}(C::StridedMatrix{R}, A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul(C, 'N', 'C', A, B)
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = is(A,B) ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)
A_mul_Bc!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'N', 'C', A, B)

Ac_mul_Bc{T<:BlasFloat}(A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper('C', 'C', A, B)
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper(C, 'C', 'C', A, B)
Ac_mul_Bt{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul('C', 'C', A, B)
Ac_mul_Bt!{T,S,R}(C::StridedMatrix{R}, A::StridedMatrix{T}, B::StridedMatrix{S}) = generic_matmatmul(C, 'C', 'C', A, B)
Ac_mul_Bc{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S}) = Ac_mul_Bc!(similar(B,promote_type(arithtype(T),arithtype(S)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::StridedMatrix{T}) = gemm_wrapper!(C, 'C', 'C', A, B)
Ac_mul_Bc!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'C', 'C', A, B)
Ac_mul_Bt{T,S}(A::StridedMatrix{T}, B::StridedMatrix{S}) = Ac_mul_Bt(similar(B, promote_type(arithtype(A),arithtype(B)), (size(A,2), size(B,1))), A, B)
Ac_mul_Bt!(C::StridedMatrix, A::StridedMatrix, B::StridedMatrix) = generic_matmatmul!(C, 'C', 'T', A, B)

# Supporting functions for matrix multiplication

Expand All @@ -158,8 +199,8 @@ function copytri!(A::StridedMatrix, uplo::Char, conjugate::Bool=false)
A
end

function gemv{T<:BlasFloat}(y::StridedVector{T}, tA::Char, A::StridedMatrix{T}, x::StridedVector{T})
stride(A, 1)==1 || return generic_matvecmul(y, tA, A, x)
function gemv!{T<:BlasFloat}(y::StridedVector{T}, tA::Char, A::StridedMatrix{T}, x::StridedVector{T})
stride(A, 1)==1 || return generic_matvecmul!(y, tA, A, x)

if tA != 'N'
(nA, mA) = size(A)
Expand All @@ -174,41 +215,44 @@ function gemv{T<:BlasFloat}(y::StridedVector{T}, tA::Char, A::StridedMatrix{T},
return BLAS.gemv!(tA, one(T), A, x, zero(T), y)
end

function syrk_wrapper{T<:BlasFloat}(tA::Char, A::StridedMatrix{T})
function syrk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedMatrix{T})
nC = chksquare(C)
if tA == 'T'
(nA, mA) = size(A)
tAt = 'N'
else
(mA, nA) = size(A)
tAt = 'T'
end
nC == mA || throw(DimensionMismatch("output matrix has size: $(nC), but should have size $(mA)"))
if mA == 0 || nA == 0; return C; end
if mA == 2 && nA == 2; return matmul2x2!(C,tA,tAt,A,A); end
if mA == 3 && nA == 3; return matmul3x3!(C,tA,tAt,A,A); end

if mA == 0 || nA == 0; return zeros(T, mA, mA); end
if mA == 2 && nA == 2; return matmul2x2(tA,tAt,A,A); end
if mA == 3 && nA == 3; return matmul3x3(tA,tAt,A,A); end

stride(A, 1) == 1 || (return generic_matmatmul(tA, tAt, A, A))
copytri!(BLAS.syrk('U', tA, one(T), A), 'U')
stride(A, 1) == 1 || (return generic_matmatmul!(C, tA, tAt, A, A))
copytri!(BLAS.syrk!('U', tA, one(T), A, zero(T), C), 'U')
end

function herk_wrapper{T<:BlasFloat}(tA::Char, A::StridedMatrix{T})
function herk_wrapper!{T<:BlasFloat}(C::StridedMatrix{T}, tA::Char, A::StridedMatrix{T})
nC = chksquare(C)
if tA == 'C'
(nA, mA) = size(A)
tAt = 'N'
else
(mA, nA) = size(A)
tAt = 'C'
end
nC == mA || throw(DimensionMismatch("output matrix has size: $(nC), but should have size $(mA)"))
if mA == 0 || nA == 0; return C; end
if mA == 2 && nA == 2; return matmul2x2!(C,tA,tAt,A,A); end
if mA == 3 && nA == 3; return matmul3x3!(C,tA,tAt,A,A); end

if mA == 2 && nA == 2; return matmul2x2(tA,tAt,A,A); end
if mA == 3 && nA == 3; return matmul3x3(tA,tAt,A,A); end

stride(A, 1) == 1 || (return generic_matmatmul(tA, tAt, A, A))
stride(A, 1) == 1 || (return generic_matmatmul!(C,tA, tAt, A, A))

# Result array does not need to be initialized as long as beta==0
# C = Array(T, mA, mA)

copytri!(BLAS.herk('U', tA, one(T), A), 'U', true)
copytri!(BLAS.herk!('U', tA, one(T), A, zero(T), C), 'U', true)
end

function gemm_wrapper{T<:BlasFloat}(tA::Char, tB::Char,
Expand All @@ -217,10 +261,10 @@ function gemm_wrapper{T<:BlasFloat}(tA::Char, tB::Char,
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, T, mA, nB)
gemm_wrapper(C, tA, tB, A, B)
gemm_wrapper!(C, tA, tB, A, B)
end

function gemm_wrapper{T<:BlasFloat}(C::StridedVecOrMat{T}, tA::Char, tB::Char,
function gemm_wrapper!{T<:BlasFloat}(C::StridedVecOrMat{T}, tA::Char, tB::Char,
A::StridedVecOrMat{T},
B::StridedMatrix{T})
mA, nA = lapack_size(tA, A)
Expand All @@ -229,10 +273,10 @@ function gemm_wrapper{T<:BlasFloat}(C::StridedVecOrMat{T}, tA::Char, tB::Char,
nA==mB || throw(DimensionMismatch("*"))

if mA == 0 || nA == 0 || nB == 0; return zeros(T, mA, nB); end
if mA == 2 && nA == 2 && nB == 2; return matmul2x2(C,tA,tB,A,B); end
if mA == 3 && nA == 3 && nB == 3; return matmul3x3(C,tA,tB,A,B); end
if mA == 2 && nA == 2 && nB == 2; return matmul2x2!(C,tA,tB,A,B); end
if mA == 3 && nA == 3 && nB == 3; return matmul3x3!(C,tA,tB,A,B); end

stride(A, 1)==stride(B, 1)==1 || (return generic_matmatmul(C, tA, tB, A, B))
stride(A, 1)==stride(B, 1)==1 || (return generic_matmatmul!(C, tA, tB, A, B))
BLAS.gemm!(tA, tB, one(T), A, B, zero(T), C)
end

Expand Down Expand Up @@ -263,18 +307,9 @@ end
# call BLAS, and convert back to required type.

# NOTE: the generic version is also called as fallback for
# strides != 1 cases in libalg_blas.jl
(*){T,S}(A::AbstractMatrix{T}, B::AbstractVector{S}) = generic_matvecmul('N', A, B)

arithtype(T) = T
arithtype(::Type{Bool}) = Int
# strides != 1 cases

function generic_matvecmul{T,S}(tA::Char, A::AbstractMatrix{T}, B::AbstractVector{S})
C = similar(B, promote_type(arithtype(T),arithtype(S)), size(A, tA=='N' ? 1 : 2))
generic_matvecmul(C, tA, A, B)
end

function generic_matvecmul{T,S,R}(C::AbstractVector{R}, tA, A::AbstractMatrix{T}, B::AbstractVector{S})
function generic_matvecmul!{T,S,R}(C::AbstractVector{R}, tA, A::AbstractMatrix{T}, B::AbstractVector{S})
mB = length(B)
mA, nA = lapack_size(tA, A)
mB==nA || throw(DimensionMismatch("*"))
Expand Down Expand Up @@ -314,32 +349,26 @@ function generic_matvecmul{T,S,R}(C::AbstractVector{R}, tA, A::AbstractMatrix{T}
C
end

(*){T,S}(A::AbstractVector{S}, B::AbstractMatrix{T}) = reshape(A,length(A),1)*B

# NOTE: the generic version is also called as fallback for strides != 1 cases
# in libalg_blas.jl
(*){T,S}(A::AbstractVecOrMat{T}, B::AbstractMatrix{S}) = generic_matmatmul('N', 'N', A, B)

function generic_matmatmul{T,S}(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, promote_type(arithtype(T),arithtype(S)), mA, nB)
generic_matmatmul(C, tA, tB, A, B)
generic_matmatmul!(C, tA, tB, A, B)
end

const tilebufsize = 10800 # Approximately 32k/3
const Abuf = Array(Uint8, tilebufsize)
const Bbuf = Array(Uint8, tilebufsize)
const Cbuf = Array(Uint8, tilebufsize)

function generic_matmatmul{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
function generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
mB==nA || throw(DimensionMismatch("*"))
if size(C,1) != mA || size(C,2) != nB; throw(DimensionMismatch("*")); end

if mA == nA == nB == 2; return matmul2x2(C, tA, tB, A, B); end
if mA == nA == nB == 3; return matmul3x3(C, tA, tB, A, B); end
if mA == nA == nB == 2; return matmul2x2!(C, tA, tB, A, B); end
if mA == nA == nB == 3; return matmul3x3!(C, tA, tB, A, B); end

@inbounds begin
if isbits(R)
Expand Down Expand Up @@ -483,10 +512,10 @@ end

# multiply 2x2 matrices
function matmul2x2{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
matmul2x2(similar(B, promote_type(T,S), 2, 2), tA, tB, A, B)
matmul2x2!(similar(B, promote_type(T,S), 2, 2), tA, tB, A, B)
end

function matmul2x2{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
function matmul2x2!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
if tA == 'T'
A11 = A[1,1]; A12 = A[2,1]; A21 = A[1,2]; A22 = A[2,2]
elseif tA == 'C'
Expand All @@ -510,10 +539,10 @@ end

# Multiply 3x3 matrices
function matmul3x3{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
matmul3x3(similar(B, promote_type(T,S), 3, 3), tA, tB, A, B)
matmul3x3!(similar(B, promote_type(T,S), 3, 3), tA, tB, A, B)
end

function matmul3x3{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
function matmul3x3!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
if tA == 'T'
A11 = A[1,1]; A12 = A[2,1]; A13 = A[3,1];
A21 = A[1,2]; A22 = A[2,2]; A23 = A[3,2];
Expand Down