@@ -493,52 +493,80 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
493
493
(\ )(A:: Union{QR,QRCompactWY,QRPivoted} , B:: Diagonal ) =
494
494
invoke (\ , Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B)
495
495
496
- function kron (A:: Diagonal{T1} , B:: Diagonal{T2} ) where {T1<: Number , T2<: Number }
496
+
497
+ @inline function kron! (C:: AbstractMatrix{T} , A:: Diagonal , B:: Diagonal ) where T
498
+ fill! (C, zero (T))
497
499
valA = A. diag; nA = length (valA)
498
500
valB = B. diag; nB = length (valB)
499
- valC = Vector {typeof(zero(T1)*zero(T2))} (undef,nA* nB)
501
+ nC = checksquare (C)
502
+ @boundscheck nC == nA* nB ||
503
+ throw (DimensionMismatch (" expect C to be a $(nA* nB) x$(nA* nB) matrix, got size $(nC) x$(nC) " ))
504
+
500
505
@inbounds for i = 1 : nA, j = 1 : nB
501
- valC[(i- 1 )* nB+ j] = valA[i] * valB[j]
506
+ idx = (i- 1 )* nB+ j
507
+ C[idx, idx] = valA[i] * valB[j]
502
508
end
503
- return Diagonal (valC)
509
+ return C
504
510
end
505
511
506
- function kron (A:: Diagonal{T} , B:: AbstractMatrix{S} ) where {T<: Number , S<: Number }
512
+ function kron (A:: Diagonal{T1} , B:: Diagonal{T2} ) where {T1<: Number , T2<: Number }
513
+ valA = A. diag; nA = length (valA)
514
+ valB = B. diag; nB = length (valB)
515
+ valC = Vector {typeof(zero(T1)*zero(T2))} (undef,nA* nB)
516
+ C = Diagonal (valC)
517
+ return @inbounds kron! (C, A, B)
518
+ end
519
+
520
+ @inline function kron! (C:: AbstractMatrix , A:: Diagonal , B:: AbstractMatrix )
507
521
Base. require_one_based_indexing (B)
508
- (mA, nA) = size (A); (mB, nB) = size (B)
509
- R = zeros (Base. promote_op (* , T, S), mA * mB, nA * nB)
522
+ (mA, nA) = size (A); (mB, nB) = size (B); (mC, nC) = size (C);
523
+ @boundscheck (mC, nC) == (mA * mB, nA * nB) ||
524
+ throw (DimensionMismatch (" expect C to be a $(mA * mB) x$(nA * nB) matrix, got size $(mC) x$(nC) " ))
510
525
m = 1
511
- for j = 1 : nA
526
+ @inbounds for j = 1 : nA
512
527
A_jj = A[j,j]
513
528
for k = 1 : nB
514
529
for l = 1 : mB
515
- R [m] = A_jj * B[l,k]
530
+ C [m] = A_jj * B[l,k]
516
531
m += 1
517
532
end
518
533
m += (nA - 1 ) * mB
519
534
end
520
535
m += mB
521
536
end
522
- return R
537
+ return C
523
538
end
524
539
525
- function kron ( A:: AbstractMatrix{T} , B:: Diagonal{S} ) where {T <: Number , S <: Number }
540
+ @inline function kron! (C :: AbstractMatrix , A:: AbstractMatrix , B:: Diagonal )
526
541
require_one_based_indexing (A)
527
- (mA, nA) = size (A); (mB, nB) = size (B)
528
- R = zeros (promote_op (* , T, S), mA * mB, nA * nB)
542
+ (mA, nA) = size (A); (mB, nB) = size (B); (mC, nC) = size (C);
543
+ @boundscheck (mC, nC) == (mA * mB, nA * nB) ||
544
+ throw (DimensionMismatch (" expect C to be a $(mA * mB) x$(nA * nB) matrix, got size $(mC) x$(nC) " ))
529
545
m = 1
530
- for j = 1 : nA
546
+ @inbounds for j = 1 : nA
531
547
for l = 1 : mB
532
548
Bll = B[l,l]
533
549
for k = 1 : mA
534
- R [m] = A[k,j] * Bll
550
+ C [m] = A[k,j] * Bll
535
551
m += nB
536
552
end
537
553
m += 1
538
554
end
539
555
m -= nB
540
556
end
541
- return R
557
+ return C
558
+ end
559
+
560
+ function kron (A:: Diagonal{T} , B:: AbstractMatrix{S} ) where {T<: Number , S<: Number }
561
+ (mA, nA) = size (A); (mB, nB) = size (B)
562
+ R = zeros (Base. promote_op (* , T, S), mA * mB, nA * nB)
563
+ return @inbounds kron! (R, A, B)
564
+ end
565
+
566
+ function kron (A:: AbstractMatrix{T} , B:: Diagonal{S} ) where {T<: Number , S<: Number }
567
+ (mA, nA) = size (A); (mB, nB) = size (B)
568
+ R = zeros (promote_op (* , T, S), mA * mB, nA * nB)
569
+ return @inbounds kron! (R, A, B)
542
570
end
543
571
544
572
conj (D:: Diagonal ) = Diagonal (conj (D. diag))
0 commit comments