Skip to content

Commit b01095e

Browse files
authored
Fix kron indexing for types without a unique zero (#56229)
This fixes a bug introduced in #55941. We may also take this opportunity to limit the scope of the `@inbounds` annotations, and also use `axes` to compute the bounds instead of hard-coding them. The real "fix" here is on line 767, where `l in 1:nA` should have been `l in 1:mB`. Using `axes` avoids such errors, and makes the operation safer as well.
1 parent fee8090 commit b01095e

File tree

2 files changed

+27
-27
lines changed

2 files changed

+27
-27
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

+25-25
Original file line numberDiff line numberDiff line change
@@ -700,16 +700,16 @@ end
700700
zerofilled = true
701701
end
702702
end
703-
@inbounds for i = 1:nA, j = 1:nB
703+
for i in eachindex(valA), j in eachindex(valB)
704704
idx = (i-1)*nB+j
705-
C[idx, idx] = valA[i] * valB[j]
705+
@inbounds C[idx, idx] = valA[i] * valB[j]
706706
end
707707
if !zerofilled
708-
for j in 1:nA, i in 1:mA
708+
for j in axes(A,2), i in axes(A,1)
709709
Δrow, Δcol = (i-1)*mB, (j-1)*nB
710-
for k in 1:nB, l in 1:mB
710+
for k in axes(B,2), l in axes(B,1)
711711
i == j && k == l && continue
712-
C[Δrow + l, Δcol + k] = A[i,j] * B[l,k]
712+
@inbounds C[Δrow + l, Δcol + k] = A[i,j] * B[l,k]
713713
end
714714
end
715715
end
@@ -749,24 +749,24 @@ end
749749
end
750750
end
751751
m = 1
752-
@inbounds for j = 1:nA
753-
A_jj = A[j,j]
754-
for k = 1:nB
755-
for l = 1:mB
756-
C[m] = A_jj * B[l,k]
752+
for j in axes(A,2)
753+
A_jj = @inbounds A[j,j]
754+
for k in axes(B,2)
755+
for l in axes(B,1)
756+
@inbounds C[m] = A_jj * B[l,k]
757757
m += 1
758758
end
759759
m += (nA - 1) * mB
760760
end
761761
if !zerofilled
762762
# populate the zero elements
763-
for i in 1:mA
763+
for i in axes(A,1)
764764
i == j && continue
765-
A_ij = A[i, j]
765+
A_ij = @inbounds A[i, j]
766766
Δrow, Δcol = (i-1)*mB, (j-1)*nB
767-
for k in 1:nB, l in 1:nA
768-
B_lk = B[l, k]
769-
C[Δrow + l, Δcol + k] = A_ij * B_lk
767+
for k in axes(B,2), l in axes(B,1)
768+
B_lk = @inbounds B[l, k]
769+
@inbounds C[Δrow + l, Δcol + k] = A_ij * B_lk
770770
end
771771
end
772772
end
@@ -792,23 +792,23 @@ end
792792
end
793793
end
794794
m = 1
795-
@inbounds for j = 1:nA
796-
for l = 1:mB
797-
Bll = B[l,l]
798-
for i = 1:mA
799-
C[m] = A[i,j] * Bll
795+
for j in axes(A,2)
796+
for l in axes(B,1)
797+
Bll = @inbounds B[l,l]
798+
for i in axes(A,1)
799+
@inbounds C[m] = A[i,j] * Bll
800800
m += nB
801801
end
802802
m += 1
803803
end
804804
if !zerofilled
805-
for i in 1:mA
806-
A_ij = A[i, j]
805+
for i in axes(A,1)
806+
A_ij = @inbounds A[i, j]
807807
Δrow, Δcol = (i-1)*mB, (j-1)*nB
808-
for k in 1:nB, l in 1:mB
808+
for k in axes(B,2), l in axes(B,1)
809809
l == k && continue
810-
B_lk = B[l, k]
811-
C[Δrow + l, Δcol + k] = A_ij * B_lk
810+
B_lk = @inbounds B[l, k]
811+
@inbounds C[Δrow + l, Δcol + k] = A_ij * B_lk
812812
end
813813
end
814814
end

stdlib/LinearAlgebra/test/diagonal.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ Random.seed!(1)
353353
D3 = Diagonal(convert(Vector{elty}, rand(n÷2)))
354354
DM3= Matrix(D3)
355355
@test Matrix(kron(D, D3)) kron(DM, DM3)
356-
M4 = rand(elty, n÷2, n÷2)
356+
M4 = rand(elty, size(D3,1) + 1, size(D3,2) + 2) # choose a different size from D3
357357
@test kron(D3, M4) kron(DM3, M4)
358358
@test kron(M4, D3) kron(M4, DM3)
359359
X = [ones(1,1) for i in 1:2, j in 1:2]
@@ -1392,7 +1392,7 @@ end
13921392
end
13931393

13941394
@testset "zeros in kron with block matrices" begin
1395-
D = Diagonal(1:2)
1395+
D = Diagonal(1:4)
13961396
B = reshape([ones(2,2), ones(3,2), ones(2,3), ones(3,3)], 2, 2)
13971397
@test kron(D, B) == kron(Array(D), B)
13981398
@test kron(B, D) == kron(B, Array(D))

0 commit comments

Comments
 (0)