Skip to content

Commit 23bfba1

Browse files
committed
add inplace kron
1 parent 8ef29e6 commit 23bfba1

File tree

8 files changed

+107
-33
lines changed

8 files changed

+107
-33
lines changed

NEWS.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Build system changes
2929

3030
New library functions
3131
---------------------
32-
32+
* New function `Base.kron!` and corresponding overloads for various matrix types for performing Kronecker product in-place. ([#31069]).
3333

3434
New library features
3535
--------------------

base/exports.jl

+1
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,7 @@ export
463463
adjoint,
464464
transpose,
465465
kron,
466+
kron!,
466467

467468
# bitarrays
468469
falses,

base/operators.jl

+2
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,8 @@ for op in (:+, :*, :&, :|, :xor, :min, :max, :kron)
542542
end
543543
end
544544

545+
function kron! end
546+
545547
const var"'" = adjoint
546548

547549
"""

stdlib/LinearAlgebra/docs/src/index.md

+1
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ Base.inv(::AbstractMatrix)
409409
LinearAlgebra.pinv
410410
LinearAlgebra.nullspace
411411
Base.kron
412+
Base.kron!
412413
LinearAlgebra.exp(::StridedMatrix{<:LinearAlgebra.BlasFloat})
413414
Base.:^(::AbstractMatrix, ::Number)
414415
Base.:^(::Number, ::AbstractMatrix)

stdlib/LinearAlgebra/src/LinearAlgebra.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import Base: \, /, *, ^, +, -, ==
1111
import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech,
1212
asin, asinh, atan, atanh, axes, big, broadcast, ceil, conj, convert, copy, copyto!, cos,
1313
cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat,
14-
getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, length, log, map, ndims,
14+
getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims,
1515
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
1616
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
1717
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec

stdlib/LinearAlgebra/src/bitarray.jl

+20-6
Original file line numberDiff line numberDiff line change
@@ -92,22 +92,29 @@ qr(A::BitMatrix) = qr(float(A))
9292

9393
## kron
9494

95-
function kron(a::BitVector, b::BitVector)
95+
@inline function kron!(R::BitVector, a::BitVector, b::BitVector)
9696
m = length(a)
9797
n = length(b)
98-
R = falses(n * m)
98+
@boundscheck length(R) == n*m || throw(DimensionMismatch())
9999
Rc = R.chunks
100100
bc = b.chunks
101101
for j = 1:m
102102
a[j] && Base.copy_chunks!(Rc, (j-1)*n+1, bc, 1, n)
103103
end
104-
R
104+
return R
105105
end
106106

107-
function kron(a::BitMatrix, b::BitMatrix)
107+
function kron(a::BitVector, b::BitVector)
108+
m = length(a)
109+
n = length(b)
110+
R = falses(n * m)
111+
return @inbounds kron!(R, a, b)
112+
end
113+
114+
function kron!(R::BitMatrix, a::BitMatrix, b::BitMatrix)
108115
mA,nA = size(a)
109116
mB,nB = size(b)
110-
R = falses(mA*mB, nA*nB)
117+
@boundscheck size(R) == (mA*mB, nA*nB) || throw(DimensionMismatch())
111118

112119
for i = 1:mA
113120
ri = (1:mB) .+ ((i-1)*mB)
@@ -118,7 +125,14 @@ function kron(a::BitMatrix, b::BitMatrix)
118125
end
119126
end
120127
end
121-
R
128+
return R
129+
end
130+
131+
function kron(a::BitMatrix, b::BitMatrix)
132+
mA,nA = size(a)
133+
mB,nB = size(b)
134+
R = falses(mA*mB, nA*nB)
135+
return @inbounds kron!(R, a, b)
122136
end
123137

124138
## Structure query functions

stdlib/LinearAlgebra/src/dense.jl

+37-9
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,29 @@ function tr(A::Matrix{T}) where T
336336
t
337337
end
338338

339+
"""
340+
kron!(C, A, B)
341+
342+
`kron!` is the in-place version of [`kron`](@ref). Computes `kron(A, B)` and stores the result in `C`
343+
overwriting the existing value of `C`.
344+
345+
!!! tip
346+
Bounds checking can be disabled by [`@inbounds`](@ref), but you need to take care of the shape
347+
of `C`, `A`, `B` yourself.
348+
"""
349+
@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
350+
require_one_based_indexing(A, B)
351+
@boundscheck (size(C) == (size(A,1)*size(B,1), size(A,2)*size(B,2))) || throw(DimensionMismatch())
352+
m = 0
353+
@inbounds for j = 1:size(A,2), l = 1:size(B,2), i = 1:size(A,1)
354+
Aij = A[i,j]
355+
for k = 1:size(B,1)
356+
C[m += 1] = Aij*B[k,l]
357+
end
358+
end
359+
return C
360+
end
361+
339362
"""
340363
kron(A, B)
341364
@@ -383,18 +406,23 @@ julia> reshape(kron(v,w), (length(w), length(v)))
383406
```
384407
"""
385408
function kron(a::AbstractMatrix{T}, b::AbstractMatrix{S}) where {T,S}
386-
require_one_based_indexing(a, b)
387409
R = Matrix{promote_op(*,T,S)}(undef, size(a,1)*size(b,1), size(a,2)*size(b,2))
388-
m = 0
389-
@inbounds for j = 1:size(a,2), l = 1:size(b,2), i = 1:size(a,1)
390-
aij = a[i,j]
391-
for k = 1:size(b,1)
392-
R[m += 1] = aij*b[k,l]
393-
end
394-
end
395-
R
410+
return @inbounds kron!(R, a, b)
396411
end
397412

413+
kron!(c::AbstractVecOrMat, a::AbstractVecOrMat, b::Number) = mul!(c, a, b)
414+
415+
Base.@propagate_inbounds function kron!(c::AbstractVector, a::AbstractVector, b::AbstractVector)
416+
C = reshape(c, length(a)*length(b), 1)
417+
A = reshape(a ,length(a), 1)
418+
B = reshape(b, length(b), 1)
419+
kron!(C, A, B)
420+
return c
421+
end
422+
423+
Base.@propagate_inbounds kron!(C::AbstractMatrix, a::AbstractMatrix, b::AbstractVector) = kron!(C, a, reshape(b, length(b), 1))
424+
Base.@propagate_inbounds kron!(C::AbstractMatrix, a::AbstractVector, b::AbstractMatrix) = kron!(C, reshape(a, length(a), 1), b)
425+
398426
kron(a::Number, b::Union{Number, AbstractVecOrMat}) = a * b
399427
kron(a::AbstractVecOrMat, b::Number) = a * b
400428
kron(a::AbstractVector, b::AbstractVector) = vec(kron(reshape(a ,length(a), 1), reshape(b, length(b), 1)))

stdlib/LinearAlgebra/src/diagonal.jl

+44-16
Original file line numberDiff line numberDiff line change
@@ -493,52 +493,80 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
493493
(\)(A::Union{QR,QRCompactWY,QRPivoted}, B::Diagonal) =
494494
invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B)
495495

496-
function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number}
496+
497+
@inline function kron!(C::AbstractMatrix{T}, A::Diagonal, B::Diagonal) where T
498+
fill!(C, zero(T))
497499
valA = A.diag; nA = length(valA)
498500
valB = B.diag; nB = length(valB)
499-
valC = Vector{typeof(zero(T1)*zero(T2))}(undef,nA*nB)
501+
nC = checksquare(C)
502+
@boundscheck nC == nA*nB ||
503+
throw(DimensionMismatch("expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)"))
504+
500505
@inbounds for i = 1:nA, j = 1:nB
501-
valC[(i-1)*nB+j] = valA[i] * valB[j]
506+
idx = (i-1)*nB+j
507+
C[idx, idx] = valA[i] * valB[j]
502508
end
503-
return Diagonal(valC)
509+
return C
504510
end
505511

506-
function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
512+
function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number}
513+
valA = A.diag; nA = length(valA)
514+
valB = B.diag; nB = length(valB)
515+
valC = Vector{typeof(zero(T1)*zero(T2))}(undef,nA*nB)
516+
C = Diagonal(valC)
517+
return @inbounds kron!(C, A, B)
518+
end
519+
520+
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix)
507521
Base.require_one_based_indexing(B)
508-
(mA, nA) = size(A); (mB, nB) = size(B)
509-
R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB)
522+
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
523+
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
524+
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
510525
m = 1
511-
for j = 1:nA
526+
@inbounds for j = 1:nA
512527
A_jj = A[j,j]
513528
for k = 1:nB
514529
for l = 1:mB
515-
R[m] = A_jj * B[l,k]
530+
C[m] = A_jj * B[l,k]
516531
m += 1
517532
end
518533
m += (nA - 1) * mB
519534
end
520535
m += mB
521536
end
522-
return R
537+
return C
523538
end
524539

525-
function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
540+
@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::Diagonal)
526541
require_one_based_indexing(A)
527-
(mA, nA) = size(A); (mB, nB) = size(B)
528-
R = zeros(promote_op(*, T, S), mA * mB, nA * nB)
542+
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
543+
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
544+
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
529545
m = 1
530-
for j = 1:nA
546+
@inbounds for j = 1:nA
531547
for l = 1:mB
532548
Bll = B[l,l]
533549
for k = 1:mA
534-
R[m] = A[k,j] * Bll
550+
C[m] = A[k,j] * Bll
535551
m += nB
536552
end
537553
m += 1
538554
end
539555
m -= nB
540556
end
541-
return R
557+
return C
558+
end
559+
560+
function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
561+
(mA, nA) = size(A); (mB, nB) = size(B)
562+
R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB)
563+
return @inbounds kron!(R, A, B)
564+
end
565+
566+
function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
567+
(mA, nA) = size(A); (mB, nB) = size(B)
568+
R = zeros(promote_op(*, T, S), mA * mB, nA * nB)
569+
return @inbounds kron!(R, A, B)
542570
end
543571

544572
conj(D::Diagonal) = Diagonal(conj(D.diag))

0 commit comments

Comments
 (0)