@@ -396,82 +396,120 @@ function lmul!(D::Diagonal, T::Tridiagonal)
396
396
return T
397
397
end
398
398
399
- function __muldiag! (out, D:: Diagonal , B, _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
399
# Kernel for the left-multiplication `D * B` with a generic matrix `B`.
# Every entry `(row, col)` of `B` is scaled by the matching diagonal entry of `D`
# and handed to `_modify!`, which (defined elsewhere) is expected to combine it
# with `out[row, col]` according to the alpha/beta stored in `_add`.
# The caller is responsible for having handled the `alpha == 0` case.
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
    dd = D.diag
    # column-major traversal: rows vary fastest in the inner loop
    @inbounds for col in axes(B, 2)
        @simd for row in axes(B, 1)
            _modify!(_add, dd[row] * B[row, col], out, (row, col))
        end
    end
    return out
end
407
# Select the pair of arrays that the triangular kernels read from and write to.
# Fallback: hand back both arguments untouched.
function _maybe_unwrap_tri(out, A)
    return out, A
end
# Destination and source are both upper triangular: return the parents of both.
function _maybe_unwrap_tri(out::UpperTriangular, A::UpperOrUnitUpperTriangular)
    return parent(out), parent(A)
end
# Destination and source are both lower triangular: return the parents of both.
function _maybe_unwrap_tri(out::LowerTriangular, A::LowerOrUnitLowerTriangular)
    return parent(out), parent(A)
end
410
# Kernel for the left-multiplication `D * B` with a triangular `B`.
#
# For each column `j` the stored triangle of `B` is scaled row-wise by `D.diag`
# and passed to `_modify!` together with `_add` (which carries alpha/beta).
# For unit-triangular `B` the diagonal entry is handled separately through
# `B[j,j]`, because it is not read from the parent array.
#
# If `out` and `B` share the same upper/lower structure, `_maybe_unwrap_tri`
# returns their parents and only the stored triangle needs to be touched.
# Otherwise (e.g. a dense `out`) the entries of `out` that correspond to the
# structural zeros of `B` must be visited as well, so that they also receive the
# alpha/beta update instead of silently keeping their previous contents.
#
# The caller is responsible for having handled the `alpha == 0` case.
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
    isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular}
    isupper = B isa UpperOrUnitUpperTriangular
    # if both B and out have the same upper/lower triangular structure,
    # we may directly read and write from the parents
    out_maybeparent, B_maybeparent = _maybe_unwrap_tri(out, B)
    # the fallback of `_maybe_unwrap_tri` returns `out` itself, so identity tells us
    # that nothing was unwrapped and the structural zeros of `B` must be processed too
    fill_unstored = out_maybeparent === out
    for j in axes(B, 2)
        if isunit
            _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
        end
        rowrange = isupper ? (1:min(j - isunit, size(B, 1))) : (j + isunit:size(B, 1))
        @inbounds @simd for i in rowrange
            _modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
        end
        if fill_unstored
            # rows strictly below (upper `B`) or strictly above (lower `B`) the diagonal;
            # `B[i,j]` is read through the wrapper so the product keeps the usual
            # element type and propagation behaviour of `D.diag[i] * zero`
            zerorange = isupper ? (j + 1:size(B, 1)) : (1:j - 1)
            @inbounds @simd for i in zerorange
                _modify!(_add, D.diag[i] * B[i,j], out, (i,j))
            end
        end
    end
    return out
end
426
# Entry point for `D * B` written into `out` with the alpha/beta held by `_add`.
# A zero alpha means the product does not contribute, so `out` is only passed to
# `_rmul_or_fill!` with beta; any other alpha is delegated to the
# `__muldiag_nonzeroalpha!` kernel selected by dispatch on `B`.
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul)
    require_one_based_indexing(out, B)
    if !iszero(_add.alpha)
        __muldiag_nonzeroalpha!(out, D, B, _add)
    else
        _rmul_or_fill!(out, _add.beta)
    end
    return out
end
421
- function __muldiag! (out, A, D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
436
+
437
# Kernel for the right-multiplication `A * D` with a generic matrix `A`.
# Column `j` of `A` is scaled by a single factor: the `j`-th diagonal entry of
# `D` after `_add` has been applied to it (this is where alpha enters).
# Because alpha is already contained in that factor, the per-element update uses
# a second `MulAddMul` whose alpha is the literal `true` and whose beta is
# unchanged.
# The caller is responsible for having handled the `alpha == 0` case.
@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
    beta = _add.beta
    unit_alpha_add = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
    dd = D.diag
    @inbounds for col in axes(A, 2)
        colscale = _add(dd[col])
        @simd for row in axes(A, 1)
            _modify!(unit_alpha_add, A[row, col] * colscale, out, (row, col))
        end
    end
    return out
end
448
# Kernel for the right-multiplication `A * D` with a triangular `A`.
#
# Column `j` of `A` is scaled by `_add(D.diag[j])`, i.e. the diagonal entry with
# alpha already applied. For unit-triangular `A` the diagonal entry is handled
# separately through `A[j,j]`, because it is not read from the parent array.
#
# If `out` and `A` share the same upper/lower structure, `_maybe_unwrap_tri`
# returns their parents and only the stored triangle needs to be touched.
# Otherwise (e.g. a dense `out`) the entries of `out` that correspond to the
# structural zeros of `A` must be visited as well, so that they also receive the
# beta update instead of silently keeping their previous contents.
#
# The caller is responsible for having handled the `alpha == 0` case.
@inline function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
    isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular}
    isupper = A isa UpperOrUnitUpperTriangular
    beta = _add.beta
    # since alpha is multiplied to the diagonal element of D,
    # we may skip alpha in the second multiplication by setting ais1 to true
    _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
    # if both A and out have the same upper/lower triangular structure,
    # we may directly read and write from the parents
    out_maybeparent, A_maybeparent = _maybe_unwrap_tri(out, A)
    # the fallback of `_maybe_unwrap_tri` returns `out` itself, so identity tells us
    # that nothing was unwrapped and the structural zeros of `A` must be processed too
    fill_unstored = out_maybeparent === out
    @inbounds for j in axes(A, 2)
        dja = _add(D.diag[j])
        if isunit
            _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
        end
        rowrange = isupper ? (1:min(j - isunit, size(A, 1))) : (j + isunit:size(A, 1))
        @simd for i in rowrange
            _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
        end
        if fill_unstored
            # rows strictly below (upper `A`) or strictly above (lower `A`) the diagonal;
            # `A[i,j]` is read through the wrapper so the product keeps the usual
            # element type and propagation behaviour of `zero * dja`
            zerorange = isupper ? (j + 1:size(A, 1)) : (1:j - 1)
            @simd for i in zerorange
                _modify!(_add_aisone, A[i,j] * dja, out, (i,j))
            end
        end
    end
    return out
end
469
# Entry point for `A * D` written into `out` with the alpha/beta held by `_add`.
# A zero alpha means the product does not contribute, so `out` is only passed to
# `_rmul_or_fill!` with beta; any other alpha is delegated to the
# `__muldiag_nonzeroalpha!` kernel selected by dispatch on `A`.
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul)
    require_one_based_indexing(out, A)
    if !iszero(_add.alpha)
        __muldiag_nonzeroalpha!(out, A, D, _add)
    else
        _rmul_or_fill!(out, _add.beta)
    end
    return out
end
445
- function __muldiag! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
479
+
480
# Kernel for `D1 * D2` when the destination is itself a `Diagonal`.
# Only the three diagonal vectors are involved: each elementwise product is
# passed to `_modify!` with `_add`, targeting the destination's diagonal vector.
# `eachindex` over all three vectors makes the loop range depend on them jointly.
# The caller is responsible for having handled the `alpha == 0` case.
@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
    lhs, rhs, dest = D1.diag, D2.diag, out.diag
    @inbounds @simd for k in eachindex(lhs, rhs, dest)
        _modify!(_add, lhs[k] * rhs[k], dest, k)
    end
    return out
end
489
# Entry point for `D1 * D2` written into a `Diagonal` destination.
# With a zero alpha only the destination's diagonal vector is passed to
# `_rmul_or_fill!` with beta; otherwise the work is delegated to
# `__muldiag_nonzeroalpha!`.
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
    if !iszero(_add.alpha)
        __muldiag_nonzeroalpha!(out, D1, D2, _add)
    else
        _rmul_or_fill!(out.diag, _add.beta)
    end
    return out
end
464
- function __muldiag! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
465
- require_one_based_indexing (out)
466
- alpha, beta = _add. alpha, _add. beta
467
- mA = size (D1, 1 )
498
# Kernel for `D1 * D2` when the destination is a general (non-`Diagonal`) array.
# Only the diagonal positions `(k, k)` of `out` are visited; each receives the
# product of the matching diagonal entries via `_modify!` with `_add`.
# Off-diagonal entries of `out` are not touched here.
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
    lhs, rhs = D1.diag, D2.diag
    @inbounds @simd for k in eachindex(lhs, rhs)
        _modify!(_add, lhs[k] * rhs[k], out, (k, k))
    end
    return out
end
506
# Entry point for `D1 * D2` written into a general (non-`Diagonal`) destination.
# The whole of `out` is first passed to `_rmul_or_fill!` with beta. Since beta has
# then already been applied everywhere, the diagonal kernel is invoked with a
# `MulAddMul` that keeps alpha but uses the literal `true` as beta, so the
# products are added on top of the already-scaled diagonal.
# Nothing further happens when alpha is zero.
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1}) where {ais1}
    require_one_based_indexing(out)
    alpha = _add.alpha
    _rmul_or_fill!(out, _add.beta)
    iszero(alpha) && return out
    accumulate_add = MulAddMul{ais1,false,typeof(alpha),Bool}(alpha, true)
    __muldiag_nonzeroalpha!(out, D1, D2, accumulate_add)
    return out
end
@@ -658,31 +696,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
658
696
@eval $ fun (A:: $Tri , D:: Diagonal ) = $ Tri ($ fun (A. data, D))
659
697
@eval $ fun (A:: $UTri , D:: Diagonal ) = $ Tri (_setdiag! ($ fun (A. data, D), $ f, D. diag))
660
698
end
699
+ @eval * (A:: $Tri{<:Any, <:StridedMaybeAdjOrTransMat} , D:: Diagonal ) =
700
+ @invoke * (A:: AbstractMatrix , D:: Diagonal )
701
+ @eval * (A:: $UTri{<:Any, <:StridedMaybeAdjOrTransMat} , D:: Diagonal ) =
702
+ @invoke * (A:: AbstractMatrix , D:: Diagonal )
661
703
for (fun, f) in zip ((:* , :lmul! , :ldiv! , :\ ), (:identity , :identity , :inv , :inv ))
662
704
@eval $ fun (D:: Diagonal , A:: $Tri ) = $ Tri ($ fun (D, A. data))
663
705
@eval $ fun (D:: Diagonal , A:: $UTri ) = $ Tri (_setdiag! ($ fun (D, A. data), $ f, D. diag))
664
706
end
707
+ @eval * (D:: Diagonal , A:: $Tri{<:Any, <:StridedMaybeAdjOrTransMat} ) =
708
+ @invoke * (D:: Diagonal , A:: AbstractMatrix )
709
+ @eval * (D:: Diagonal , A:: $UTri{<:Any, <:StridedMaybeAdjOrTransMat} ) =
710
+ @invoke * (D:: Diagonal , A:: AbstractMatrix )
665
711
# 3-arg ldiv!
666
712
@eval ldiv! (C:: $Tri , D:: Diagonal , A:: $Tri ) = $ Tri (ldiv! (C. data, D, A. data))
667
713
@eval ldiv! (C:: $Tri , D:: Diagonal , A:: $UTri ) = $ Tri (_setdiag! (ldiv! (C. data, D, A. data), inv, D. diag))
668
- # 3-arg mul! is disambiguated in special.jl
669
- # 5-arg mul!
670
- @eval _mul! (C:: $Tri , D:: Diagonal , A:: $Tri , _add) = $ Tri (mul! (C. data, D, A. data, _add. alpha, _add. beta))
671
- @eval function _mul! (C:: $Tri , D:: Diagonal , A:: $UTri , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
672
- α, β = _add. alpha, _add. beta
673
- iszero (α) && return _rmul_or_fill! (C, β)
674
- diag′ = bis0 ? nothing : diag (C)
675
- data = mul! (C. data, D, A. data, α, β)
676
- $ Tri (_setdiag! (data, _add, D. diag, diag′))
677
- end
678
- @eval _mul! (C:: $Tri , A:: $Tri , D:: Diagonal , _add) = $ Tri (mul! (C. data, A. data, D, _add. alpha, _add. beta))
679
- @eval function _mul! (C:: $Tri , A:: $UTri , D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
680
- α, β = _add. alpha, _add. beta
681
- iszero (α) && return _rmul_or_fill! (C, β)
682
- diag′ = bis0 ? nothing : diag (C)
683
- data = mul! (C. data, A. data, D, α, β)
684
- $ Tri (_setdiag! (data, _add, D. diag, diag′))
685
- end
686
714
end
687
715
688
716
@inline function kron! (C:: AbstractMatrix , A:: Diagonal , B:: Diagonal )
0 commit comments