Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add in-place kron #31069

Merged
merged 1 commit into from
May 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Build system changes

New library functions
---------------------

* New function `Base.kron!` and corresponding overloads for various matrix types for computing Kronecker products in-place ([#31069]).

New library features
--------------------
Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ export
adjoint,
transpose,
kron,
kron!,

# bitarrays
falses,
Expand Down
2 changes: 2 additions & 0 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ for op in (:+, :*, :&, :|, :xor, :min, :max, :kron)
end
end

# Method-less generic function stub: it only creates the name `kron!` here; the
# LinearAlgebra and SparseArrays stdlibs import it from Base and attach the methods.
function kron! end

Comment on lines +545 to +546
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this good for? For stack traces or something like that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because `kron` is defined in Base (https://github.com/JuliaLang/julia/blob/master/base/operators.jl#L533), so every other stdlib has to `import Base: kron`. I think it'd be more consistent to define the generic function in Base instead of LinearAlgebra.

const var"'" = adjoint

"""
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ Base.inv(::AbstractMatrix)
LinearAlgebra.pinv
LinearAlgebra.nullspace
Base.kron
Base.kron!
LinearAlgebra.exp(::StridedMatrix{<:LinearAlgebra.BlasFloat})
Base.:^(::AbstractMatrix, ::Number)
Base.:^(::Number, ::AbstractMatrix)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Base: \, /, *, ^, +, -, ==
import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech,
asin, asinh, atan, atanh, axes, big, broadcast, ceil, conj, convert, copy, copyto!, cos,
cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat,
getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, length, log, map, ndims,
getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims,
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
Expand Down
26 changes: 20 additions & 6 deletions stdlib/LinearAlgebra/src/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,29 @@ qr(A::BitMatrix) = qr(float(A))

## kron

function kron(a::BitVector, b::BitVector)
# In-place Kronecker product of two `BitVector`s: writes `kron(a, b)` into `R` and
# returns `R`. Throws `DimensionMismatch` when `length(R) != length(a) * length(b)`
# (the check sits in a `@boundscheck` block, so callers may elide it with `@inbounds`).
@inline function kron!(R::BitVector, a::BitVector, b::BitVector)
    m = length(a)
    n = length(b)
    @boundscheck length(R) == n*m ||
        throw(DimensionMismatch("expect R to have length $(n*m), got length $(length(R))"))
    # The loop only writes the segments selected by set bits of `a`, so `R` must be
    # cleared first; otherwise stale bits survive wherever `a[j]` is false.
    fill!(R, false)
    Rc = R.chunks
    bc = b.chunks
    for j = 1:m
        # Segment `j` of the result is a copy of `b` when `a[j]` is set.
        a[j] && Base.copy_chunks!(Rc, (j-1)*n+1, bc, 1, n)
    end
    return R
end

function kron(a::BitMatrix, b::BitMatrix)
# Kronecker product of two `BitVector`s, returned as a newly allocated `BitVector`.
function kron(a::BitVector, b::BitVector)
    # All-false buffer of length `length(a) * length(b)`, handed to `kron!` to fill.
    out = falses(length(b) * length(a))
    # The buffer length is right by construction, so the callee's size check is elided.
    return @inbounds kron!(out, a, b)
end

function kron!(R::BitMatrix, a::BitMatrix, b::BitMatrix)
mA,nA = size(a)
mB,nB = size(b)
R = falses(mA*mB, nA*nB)
@boundscheck size(R) == (mA*mB, nA*nB) || throw(DimensionMismatch())

for i = 1:mA
ri = (1:mB) .+ ((i-1)*mB)
Expand All @@ -118,7 +125,14 @@ function kron(a::BitMatrix, b::BitMatrix)
end
end
end
R
return R
end

# Kronecker product of two `BitMatrix` values, returned as a newly allocated `BitMatrix`.
function kron(a::BitMatrix, b::BitMatrix)
    nrows = size(a, 1) * size(b, 1)
    ncols = size(a, 2) * size(b, 2)
    # All-false buffer of the product shape, handed to `kron!` to fill.
    out = falses(nrows, ncols)
    # The buffer shape is right by construction, so the callee's size check is elided.
    return @inbounds kron!(out, a, b)
end

## Structure query functions
Expand Down
46 changes: 37 additions & 9 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,29 @@ function tr(A::Matrix{T}) where T
t
end

"""
    kron!(C, A, B)

`kron!` is the in-place version of [`kron`](@ref). Computes `kron(A, B)`, stores the result in
`C` (overwriting the existing value of `C`) and returns `C`.

`C` must have size `(size(A,1)*size(B,1), size(A,2)*size(B,2))`; otherwise a
`DimensionMismatch` is thrown.

!!! tip
    Bounds checking can be disabled by [`@inbounds`](@ref), but you need to take care of the shape
    of `C`, `A`, `B` yourself.
"""
@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
    require_one_based_indexing(A, B)
    mA, nA = size(A)
    mB, nB = size(B)
    @boundscheck size(C) == (mA*mB, nA*nB) ||
        throw(DimensionMismatch("expect C to be a $(mA*mB)x$(nA*nB) matrix, got size $(size(C,1))x$(size(C,2))"))
    # `m` is a running linear index into `C`. The loop nest (columns of A, columns of B,
    # rows of A, rows of B) produces the entries of the result in exactly that order,
    # so each value can simply be stored at the next linear position.
    m = 0
    @inbounds for j = 1:nA, l = 1:nB, i = 1:mA
        Aij = A[i,j]
        for k = 1:mB
            C[m += 1] = Aij*B[k,l]
        end
    end
    return C
end

"""
kron(A, B)

Expand Down Expand Up @@ -383,18 +406,23 @@ julia> reshape(kron(v,w), (length(w), length(v)))
```
"""
function kron(a::AbstractMatrix{T}, b::AbstractMatrix{S}) where {T,S}
require_one_based_indexing(a, b)
R = Matrix{promote_op(*,T,S)}(undef, size(a,1)*size(b,1), size(a,2)*size(b,2))
m = 0
@inbounds for j = 1:size(a,2), l = 1:size(b,2), i = 1:size(a,1)
aij = a[i,j]
for k = 1:size(b,1)
R[m += 1] = aij*b[k,l]
end
end
R
return @inbounds kron!(R, a, b)
end

# A Kronecker product with a scalar is plain scaling, so forward to `mul!`, which
# writes the scaled array into `c`.
kron!(c::AbstractVecOrMat, a::AbstractVecOrMat, b::Number) = mul!(c, a, b)
# Scalar-first counterpart, so that `kron!` accepts the scalar in either position
# just like the out-of-place `kron` methods do.
kron!(c::AbstractVecOrMat, a::Number, b::AbstractVecOrMat) = mul!(c, a, b)

# Vector ⊗ vector, in place: view all three vectors as single-column matrices,
# delegate to the matrix method, and hand back the original vector `c`.
Base.@propagate_inbounds function kron!(c::AbstractVector, a::AbstractVector, b::AbstractVector)
    ccol = reshape(c, length(a)*length(b), 1)
    acol = reshape(a, length(a), 1)
    bcol = reshape(b, length(b), 1)
    kron!(ccol, acol, bcol)
    return c
end

# Mixed matrix/vector arguments: the vector is reshaped to a single-column matrix
# and the call is forwarded to the matrix ⊗ matrix method.
Base.@propagate_inbounds function kron!(C::AbstractMatrix, a::AbstractMatrix, b::AbstractVector)
    bcol = reshape(b, length(b), 1)
    return kron!(C, a, bcol)
end
Base.@propagate_inbounds function kron!(C::AbstractMatrix, a::AbstractVector, b::AbstractMatrix)
    acol = reshape(a, length(a), 1)
    return kron!(C, acol, b)
end

# With a scalar on either side the Kronecker product is ordinary multiplication.
function kron(a::Number, b::Union{Number, AbstractVecOrMat})
    return a * b
end
function kron(a::AbstractVecOrMat, b::Number)
    return a * b
end
# Vector ⊗ vector: treat both as single-column matrices, then flatten the result with `vec`.
function kron(a::AbstractVector, b::AbstractVector)
    acol = reshape(a, length(a), 1)
    bcol = reshape(b, length(b), 1)
    return vec(kron(acol, bcol))
end
Expand Down
60 changes: 44 additions & 16 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -493,52 +493,80 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
(\)(A::Union{QR,QRCompactWY,QRPivoted}, B::Diagonal) =
invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B)

function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number}

@inline function kron!(C::AbstractMatrix{T}, A::Diagonal, B::Diagonal) where T
fill!(C, zero(T))
valA = A.diag; nA = length(valA)
valB = B.diag; nB = length(valB)
valC = Vector{typeof(zero(T1)*zero(T2))}(undef,nA*nB)
nC = checksquare(C)
@boundscheck nC == nA*nB ||
throw(DimensionMismatch("expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)"))

@inbounds for i = 1:nA, j = 1:nB
valC[(i-1)*nB+j] = valA[i] * valB[j]
idx = (i-1)*nB+j
C[idx, idx] = valA[i] * valB[j]
end
return Diagonal(valC)
return C
end

function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
# Kronecker product of two numeric `Diagonal` matrices, returned as a `Diagonal`.
function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number}
    # Element type of the result is the type of a product of one entry from each operand.
    Tc = typeof(zero(T1)*zero(T2))
    # Uninitialised diagonal storage of length `length(A.diag) * length(B.diag)`;
    # it is wrapped in a `Diagonal` and passed to `kron!` to be filled.
    diagC = Vector{Tc}(undef, length(A.diag) * length(B.diag))
    out = Diagonal(diagC)
    return @inbounds kron!(out, A, B)
end

# In-place `kron(A, B)` for a diagonal `A` and a general matrix `B`. The result is
# written to `C`, which is returned. Throws `DimensionMismatch` unless `C` is
# `(size(A,1)*size(B,1)) x (size(A,2)*size(B,2))` (the check can be elided with `@inbounds`).
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix)
    Base.require_one_based_indexing(B)
    (mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
    @boundscheck (mC, nC) == (mA * mB, nA * nB) ||
        throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
    # The loop below stores only the diagonal blocks `A[j,j] * B`; every other entry of
    # the result is zero, so whatever `C` held before has to be cleared first.
    fill!(C, zero(eltype(C)))
    # `m` is a running linear index into `C`.
    m = 1
    @inbounds for j = 1:nA
        A_jj = A[j,j]
        for k = 1:nB
            for l = 1:mB
                C[m] = A_jj * B[l,k]
                m += 1
            end
            # Skip the rest of this column: advance by one full column (nA*mB entries)
            # measured from the start of the block just written.
            m += (nA - 1) * mB
        end
        # Move down by one block so the next diagonal block starts mB rows lower.
        m += mB
    end
    return C
end

function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::Diagonal)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also have a @propagate_inbounds, I suppose?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, yes!

require_one_based_indexing(A)
(mA, nA) = size(A); (mB, nB) = size(B)
R = zeros(promote_op(*, T, S), mA * mB, nA * nB)
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
m = 1
for j = 1:nA
@inbounds for j = 1:nA
for l = 1:mB
Bll = B[l,l]
for k = 1:mA
R[m] = A[k,j] * Bll
C[m] = A[k,j] * Bll
m += nB
end
m += 1
end
m -= nB
end
return R
return C
end

# Kronecker product of a numeric `Diagonal` with a numeric matrix, returned as a dense matrix.
function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
    nrows = size(A, 1) * size(B, 1)
    ncols = size(A, 2) * size(B, 2)
    # Zero-initialised buffer of the product shape, handed to `kron!` to fill.
    out = zeros(Base.promote_op(*, T, S), nrows, ncols)
    return @inbounds kron!(out, A, B)
end

# Kronecker product of a numeric matrix with a numeric `Diagonal`, returned as a dense matrix.
function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
    nrows = size(A, 1) * size(B, 1)
    ncols = size(A, 2) * size(B, 2)
    # Zero-initialised buffer of the product shape, handed to `kron!` to fill.
    out = zeros(promote_op(*, T, S), nrows, ncols)
    return @inbounds kron!(out, A, B)
end

conj(D::Diagonal) = Diagonal(conj(D.diag))
Expand Down
2 changes: 1 addition & 1 deletion stdlib/SparseArrays/src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import Base: acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
tand, tanh, trunc, abs, abs2,
broadcast, ceil, complex, conj, convert, copy, copyto!, adjoint,
exp, expm1, findall, findmax, findmin, float, getindex,
vcat, hcat, hvcat, cat, imag, argmax, kron, length, log, log1p, max, min,
vcat, hcat, hvcat, cat, imag, argmax, kron, kron!, length, log, log1p, max, min,
maximum, minimum, one, promote_eltype, real, reshape, rot180,
rotl90, rotr90, round, setindex!, similar, size, transpose,
vec, permute!, map, map!, Array, diff, circshift!, circshift
Expand Down
75 changes: 61 additions & 14 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1295,16 +1295,21 @@ function opnormestinv(A::AbstractSparseMatrixCSC{T}, t::Integer = min(2,maximum(
end

## kron

# sparse matrix ⊗ sparse matrix
function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S2}) where {T1,S1,T2,S2}
@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
nnzC = nnz(A)*nnz(B)
mA, nA = size(A); mB, nB = size(B)
mC, nC = mA*mB, nA*nB
colptrC = Vector{promote_type(S1,S2)}(undef, nC+1)
rowvalC = Vector{promote_type(S1,S2)}(undef, nnzC)
nzvalC = Vector{typeof(one(T1)*one(T2))}(undef, nnzC)
colptrC[1] = 1

rowvalC = rowvals(C)
nzvalC = nonzeros(C)
colptrC = getcolptr(C)

@boundscheck begin
length(colptrC) == nC+1 || throw(DimensionMismatch("expect C to be preallocated with $(nC+1) colptrs "))
length(rowvalC) == nnzC || throw(DimensionMismatch("expect C to be preallocated with $(nnzC) rowvals"))
length(nzvalC) == nnzC || throw(DimensionMismatch("expect C to be preallocated with $(nnzC) nzvals"))
end

col = 1
@inbounds for j = 1:nA
startA = getcolptr(A)[j]
Expand All @@ -1328,7 +1333,43 @@ function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S
end
end
end
return SparseMatrixCSC(mC, nC, colptrC, rowvalC, nzvalC)
return C
end

# In-place Kronecker product of sparse vectors: overwrites the stored indices and
# values of `z` with those of `kron(x, y)` and returns `z`.
# `z` must have length `length(x) * length(y)` and hold exactly `nnz(x) * nnz(y)`
# stored entries; otherwise a `DimensionMismatch` is thrown. The checks sit in a
# `@boundscheck` block, so callers that guarantee the layout can elide them with `@inbounds`.
@inline function kron!(z::SparseVector, x::SparseVector, y::SparseVector)
    nnzx = nnz(x); nnzy = nnz(y)
    leny = length(y)
    nzind = nonzeroinds(z)
    nzval = nonzeros(z)

    @boundscheck begin
        nnzval = length(nzval); nnzind = length(nzind)
        nnzz = nnzx*nnzy
        lenz = length(x)*leny
        # A wrongly sized `z` would otherwise silently receive out-of-range indices.
        length(z) == lenz || throw(DimensionMismatch("expect z to have length $lenz, got length $(length(z))"))
        nnzval == nnzz || throw(DimensionMismatch("expect z to be preallocated with $nnzz nonzeros"))
        nnzind == nnzz || throw(DimensionMismatch("expect z to be preallocated with $nnzz nonzeros"))
    end

    # Fetch the operands' index/value arrays once instead of on every iteration.
    xind = nonzeroinds(x); xval = nonzeros(x)
    yind = nonzeroinds(y); yval = nonzeros(y)
    @inbounds for i = 1:nnzx, j = 1:nnzy
        # Stored entry (i, j) lands at slot (i-1)*nnzy + j of the output arrays.
        this_ind = (i-1)*nnzy+j
        nzind[this_ind] = (xind[i]-1)*leny + yind[j]
        nzval[this_ind] = xval[i] * yval[j]
    end
    return z
end

# sparse matrix ⊗ sparse matrix
# Allocates the three CSC arrays of the result at their final sizes, wraps them in a
# `SparseMatrixCSC`, and passes that to `kron!` to fill.
function kron(A::AbstractSparseMatrixCSC{T1,S1}, B::AbstractSparseMatrixCSC{T2,S2}) where {T1,S1,T2,S2}
    Tv = typeof(one(T1)*one(T2))
    Ti = promote_type(S1,S2)
    mC = size(A, 1) * size(B, 1)
    nC = size(A, 2) * size(B, 2)
    nnzC = nnz(A)*nnz(B)
    colptrC = Vector{Ti}(undef, nC+1)
    colptrC[1] = 1
    rowvalC = Vector{Ti}(undef, nnzC)
    nzvalC = Vector{Tv}(undef, nnzC)
    # skip sparse_check
    C = SparseMatrixCSC{Tv, Ti}(mC, nC, colptrC, rowvalC, nzvalC)
    return @inbounds kron!(C, A, B)
end

# sparse vector ⊗ sparse vector
Expand All @@ -1337,27 +1378,33 @@ function kron(x::SparseVector{T1,S1}, y::SparseVector{T2,S2}) where {T1,S1,T2,S2
nnzz = nnzx*nnzy # number of nonzeros in new vector
nzind = Vector{promote_type(S1,S2)}(undef, nnzz) # the indices of nonzeros
nzval = Vector{typeof(one(T1)*one(T2))}(undef, nnzz) # the values of nonzeros
@inbounds for i = 1:nnzx, j = 1:nnzy
this_ind = (i-1)*nnzy+j
nzind[this_ind] = (nonzeroinds(x)[i]-1)*length(y::SparseVector) + nonzeroinds(y)[j]
nzval[this_ind] = nonzeros(x)[i] * nonzeros(y)[j]
end
return SparseVector(length(x::SparseVector)*length(y::SparseVector), nzind, nzval)
z = SparseVector(length(x)*length(y), nzind, nzval)
return @inbounds kron!(z, x, y)
end

# sparse matrix ⊗ sparse vector & vice versa
# (the vector operand is converted with `SparseMatrixCSC(x)` and the call is forwarded)
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, x::SparseVector)
    return kron!(C, A, SparseMatrixCSC(x))
end
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, x::SparseVector, A::AbstractSparseMatrixCSC)
    return kron!(C, SparseMatrixCSC(x), A)
end

function kron(A::AbstractSparseMatrixCSC, x::SparseVector)
    return kron(A, SparseMatrixCSC(x))
end
function kron(x::SparseVector, A::AbstractSparseMatrixCSC)
    return kron(SparseMatrixCSC(x), A)
end

# sparse vec/mat ⊗ vec/mat and vice versa
# (the `VecOrMat` operand is converted with `sparse` and the call is forwarded)
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, A::Union{SparseVector,AbstractSparseMatrixCSC}, B::VecOrMat)
    return kron!(C, A, sparse(B))
end
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC})
    return kron!(C, sparse(A), B)
end

function kron(A::Union{SparseVector,AbstractSparseMatrixCSC}, B::VecOrMat)
    return kron(A, sparse(B))
end
function kron(A::VecOrMat, B::Union{SparseVector,AbstractSparseMatrixCSC})
    return kron(sparse(A), B)
end

# sparse vec/mat ⊗ Diagonal and vice versa
# (the `Diagonal` operand is converted with `sparse` and the call is forwarded)
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, A::Diagonal{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}}) where {T<:Number, S<:Number}
    return kron!(C, sparse(A), B)
end
Base.@propagate_inbounds function kron!(C::SparseMatrixCSC, A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}, B::Diagonal{S}) where {T<:Number, S<:Number}
    return kron!(C, A, sparse(B))
end

function kron(A::Diagonal{T}, B::Union{SparseVector{S}, AbstractSparseMatrixCSC{S}}) where {T<:Number, S<:Number}
    return kron(sparse(A), B)
end
function kron(A::Union{SparseVector{T}, AbstractSparseMatrixCSC{T}}, B::Diagonal{S}) where {T<:Number, S<:Number}
    return kron(A, sparse(B))
end

# sparse outer product
# (implemented as an elementwise `*` broadcast of the two operands)
function kron!(C::SparseMatrixCSC, A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion)
    return broadcast!(*, C, A, B)
end
function kron(A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion)
    return A .* B
end

## det, inv, cond
Expand Down