Skip to content

Commit 09f9050

Browse files
committed
add dim-checker helper, resolve type instability, tests
1 parent 0e78eb7 commit 09f9050

File tree

3 files changed

+51
-23
lines changed

3 files changed

+51
-23
lines changed

src/composition.jl

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
# helper function
2+
check_dim_mul(A, B) = size(A, 2) == size(B, 1)
3+
14
struct CompositeMap{T, As<:Tuple{Vararg{LinearMap}}} <: LinearMap{T}
25
maps::As # stored in order of application to vector
36
function CompositeMap{T, As}(maps::As) where {T, As}
47
N = length(maps)
58
for n in 2:N
6-
size(maps[n], 2) == size(maps[n-1], 1) || throw(DimensionMismatch("CompositeMap"))
9+
check_dim_mul(maps[n], maps[n-1]) || throw(DimensionMismatch("CompositeMap"))
710
end
811
for n in 1:N
912
promote_type(T, eltype(maps[n])) == T || throw(InexactError())
@@ -71,22 +74,22 @@ Base.:(-)(A::LinearMap) = -1 * A
7174

7275
# composition of linear maps
7376
function Base.:(*)(A₁::CompositeMap, A₂::CompositeMap)
74-
size(A₁, 2) == size(A₂, 1) || throw(DimensionMismatch("*"))
77+
check_dim_mul(A₁, A₂) || throw(DimensionMismatch("*"))
7578
T = promote_type(eltype(A₁), eltype(A₂))
7679
return CompositeMap{T}(tuple(A₂.maps..., A₁.maps...))
7780
end
7881
function Base.:(*)(A₁::LinearMap, A₂::CompositeMap)
79-
size(A₁, 2) == size(A₂, 1) || throw(DimensionMismatch("*"))
82+
check_dim_mul(A₁, A₂) || throw(DimensionMismatch("*"))
8083
T = promote_type(eltype(A₁), eltype(A₂))
8184
return CompositeMap{T}(tuple(A₂.maps..., A₁))
8285
end
8386
function Base.:(*)(A₁::CompositeMap, A₂::LinearMap)
84-
size(A₁, 2) == size(A₂, 1) || throw(DimensionMismatch("*"))
87+
check_dim_mul(A₁, A₂) || throw(DimensionMismatch("*"))
8588
T = promote_type(eltype(A₁), eltype(A₂))
8689
return CompositeMap{T}(tuple(A₂, A₁.maps...))
8790
end
8891
function Base.:(*)(A₁::LinearMap, A₂::LinearMap)
89-
size(A₁, 2) == size(A₂, 1) || throw(DimensionMismatch("*"))
92+
check_dim_mul(A₁, A₂) || throw(DimensionMismatch("*"))
9093
T = promote_type(eltype(A₁), eltype(A₂))
9194
return CompositeMap{T}(tuple(A₂, A₁))
9295
end

src/kronecker.jl

+31-13
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,13 @@ julia> Matrix(Δ)
4141
0 1 1 -4
4242
```
4343
"""
44-
# Base.kron(A::LinearMap, B::LinearMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}((A, B))
45-
# Base.kron(A::LinearMap{TA}, B::LinearMap{TB}) where {TA,TB} = KroneckerMap{promote_type(TA,TB)}((A, B))
46-
Base.kron(As::LinearMap...) = KroneckerMap{promote_type(map(eltype, As)...)}(tuple(As...))
47-
Base.kron(A::LinearMap, B::AbstractArray) = kron(A, LinearMap(B))
48-
Base.kron(A::AbstractArray, B::LinearMap) = kron(LinearMap(A), B)
44+
Base.kron(A::LinearMap, B::LinearMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}((A, B))
45+
Base.kron(A::LinearMap, B::KroneckerMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A, B.maps...))
46+
Base.kron(A::KroneckerMap, B::LinearMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B))
47+
Base.kron(A::KroneckerMap, B::KroneckerMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B.maps...))
48+
Base.kron(A::LinearMap, B::LinearMap, Cs::LinearMap...) = KroneckerMap{promote_type(eltype(A), eltype(B), map(eltype, Cs)...)}(tuple(A, B, Cs...))
49+
Base.kron(A::AbstractMatrix, B::LinearMap) = kron(LinearMap(A), B)
50+
Base.kron(A::LinearMap, B::AbstractMatrix) = kron(A, LinearMap(B))
4951
# promote AbstractMatrix arguments to LinearMaps, then take LinearMap-Kronecker product
5052
for k in 3:8 # is 8 sufficient?
5153
Is = ntuple(n->:($(Symbol(:A,n))::AbstractMatrix), Val(k-1))
@@ -75,14 +77,6 @@ LinearAlgebra.ishermitian(A::KroneckerMap) = all(ishermitian, A.maps)
7577
LinearAlgebra.adjoint(A::KroneckerMap{T}) where {T} = KroneckerMap{T}(map(adjoint, A.maps))
7678
LinearAlgebra.transpose(A::KroneckerMap{T}) where {T} = KroneckerMap{T}(map(transpose, A.maps))
7779

78-
function Base.:(*)(A::KroneckerMap, B::KroneckerMap)
79-
if length(A.maps) == length(B.maps) && all(M -> size(M[1], 2) == size(M[2], 1), zip(A.maps, B.maps))
80-
return kron(map(prod, zip(A.maps, B.maps))...)
81-
else
82-
return CompositeMap{T}(tuple(B, A))
83-
end
84-
end
85-
8680
Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps == B.maps)
8781

8882
function LinearMaps.A_mul_B!(y::AbstractVector, L::KroneckerMap{T,<:NTuple{2,LinearMap}}, x::AbstractVector) where {T}
@@ -102,6 +96,30 @@ function LinearMaps.A_mul_B!(y::AbstractVector, L::KroneckerMap{T}, x::AbstractV
10296
_kronmul!(y, B, X, transpose(A), T)
10397
return y
10498
end
99+
# mixed-product rule, prefer the right if possible
100+
# (A₁ ⊗ A₂ ⊗ ... ⊗ Aᵣ) * (B₁ ⊗ B₂ ⊗ ... ⊗ Bᵣ) = (A₁B₁) ⊗ (A₂B₂) ⊗ ... ⊗ (AᵣBᵣ)
101+
function A_mul_B!(y::AbstractVector, L::CompositeMap{<:Any,<:Tuple{KroneckerMap,KroneckerMap}}, x::AbstractVector)
102+
B, A = L.maps
103+
if length(A.maps) == length(B.maps) && all(M -> check_dim_mul(M[1], M[2]), zip(A.maps, B.maps))
104+
A_mul_B!(y, kron(map(prod, zip(A.maps, B.maps))...), x)
105+
else
106+
A_mul_B!(y, LinearMap(A)*B, x)
107+
end
108+
end
109+
# mixed-product rule, prefer the right if possible
110+
# (A₁ ⊗ B₁)*(A₂⊗B₂)*...*(Aᵣ⊗Bᵣ) = (A₁*A₂*...*Aᵣ) ⊗ (B₁*B₂*...*Bᵣ)
111+
function A_mul_B!(y::AbstractVector, L::CompositeMap{T,<:Tuple{Vararg{KroneckerMap{<:Any,<:Tuple{LinearMap,LinearMap}}}}}, x::AbstractVector) where {T}
112+
As = map(AB -> AB.maps[1], L.maps)
113+
Bs = map(AB -> AB.maps[2], L.maps)
114+
As1, As2 = Base.front(As), Base.tail(As)
115+
Bs1, Bs2 = Base.front(Bs), Base.tail(Bs)
116+
apply = all(A -> check_dim_mul(A...), zip(As1, As2)) && all(A -> check_dim_mul(A...), zip(Bs1, Bs2))
117+
if apply
118+
A_mul_B!(y, kron(prod(As), prod(Bs)), x)
119+
else
120+
A_mul_B!(y, CompositeMap{T}(map(LinearMap, L.maps)), x)
121+
end
122+
end
105123

106124
function _kronmul!(y, B, X, At, T)
107125
na, ma = size(At)

test/kronecker.jl

+12-5
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,23 @@ using Test, LinearMaps, LinearAlgebra
2222
K = @inferred kron(A, A, A, LA)
2323
@test K isa LinearMaps.KroneckerMap
2424
@test Matrix(K) kron(A, A, A, A)
25-
@test K*K isa LinearMaps.KroneckerMap{ComplexF64,<:Tuple{Vararg{LinearMaps.CompositeMap}}}
25+
@test Matrix(@inferred K*K) kron(A, A, A, A)*kron(A, A, A, A)
2626
K4 = @inferred kron(A, B, B, LB)
2727
# check that matrices don't get Kronecker-multiplied, but that all is lazy
2828
@test K4.maps[1].lmap === A
2929
@test @inferred kron(LA, LB)' == @inferred kron(LA', LB')
30-
@test kron(LA, B) == kron(LA, LB) == kron(A, LB)
31-
@test ishermitian(kron(LA'LA, LB'LB))
30+
@test (@inferred kron(LA, B)) == (@inferred kron(LA, LB)) == (@inferred kron(A, LB))
31+
@test @inferred ishermitian(kron(LA'LA, LB'LB))
3232
A = rand(3, 3); B = rand(2, 2); LA = LinearMap(A); LB = LinearMap(B)
33-
@test issymmetric(kron(LA'LA, LB'LB))
34-
@test ishermitian(kron(LA'LA, LB'LB))
33+
@test @inferred issymmetric(kron(LA'LA, LB'LB))
34+
@test @inferred ishermitian(kron(LA'LA, LB'LB))
35+
# use mixed-product rule
36+
K = kron(LA, LB) * kron(LA, LB) * kron(LA, LB)
37+
@test Matrix(K) kron(A, B)^3
38+
# example that doesn't use mixed-product rule
39+
A = rand(3, 2); B = rand(2, 3)
40+
K = @inferred kron(A, LinearMap(B))
41+
@test Matrix(@inferred K*K) kron(A, B)*kron(A, B)
3542
end
3643
@testset "Kronecker sum" begin
3744
A = rand(ComplexF64, 3, 3)

0 commit comments

Comments
 (0)