Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e0c7d41

Browse files
committedJan 9, 2018
Make Adjoint/Transpose behave like typical constructors.
1 parent 29dc306 commit e0c7d41

File tree

2 files changed

+66
-70
lines changed

2 files changed

+66
-70
lines changed
 

‎base/linalg/adjtrans.jl

+50-54
Original file line numberDiff line numberDiff line change
@@ -11,46 +11,42 @@ import Base: length, size, axes, IndexStyle, getindex, setindex!, parent, vec, c
1111
struct Adjoint{T,S} <: AbstractMatrix{T}
1212
parent::S
1313
function Adjoint{T,S}(A::S) where {T,S}
14-
checkeltype(Adjoint, T, eltype(A))
14+
checkeltype_adjoint(T, eltype(A))
1515
new(A)
1616
end
1717
end
1818
struct Transpose{T,S} <: AbstractMatrix{T}
1919
parent::S
2020
function Transpose{T,S}(A::S) where {T,S}
21-
checkeltype(Transpose, T, eltype(A))
21+
checkeltype_transpose(T, eltype(A))
2222
new(A)
2323
end
2424
end
2525

26-
function checkeltype(::Type{Transform}, ::Type{ResultEltype}, ::Type{ParentEltype}) where {Transform, ResultEltype, ParentEltype}
27-
if ResultEltype !== transformtype(Transform, ParentEltype)
28-
error(string("Element type mismatch. Tried to create an `$Transform{$ResultEltype}` ",
29-
"from an object with eltype `$ParentEltype`, but the element type of the ",
30-
"`$Transform` of an object with eltype `$ParentEltype` must be ",
31-
"`$(transformtype(Transform, ParentEltype))`"))
32-
end
26+
function checkeltype_adjoint(::Type{ResultEltype}, ::Type{ParentEltype}) where {ResultEltype,ParentEltype}
27+
ResultEltype === Base.promote_op(adjoint, ParentEltype) || error(string(
28+
"Element type mismatch. Tried to create an `Adjoint{$ResultEltype}` ",
29+
"from an object with eltype `$ParentEltype`, but the element type of ",
30+
"the adjoint of an object with eltype `$ParentEltype` must be ",
31+
"`$(Base.promote_op(adjoint, ParentEltype))`."))
3332
return nothing
3433
end
35-
function transformtype(::Type{O}, ::Type{S}) where {O,S}
36-
# similar to promote_op(::Any, ::Type)
37-
@_inline_meta
38-
T = _return_type(O, Tuple{_default_type(S)})
39-
_isleaftype(S) && return _isleaftype(T) ? T : Any
40-
return typejoin(S, T)
34+
function checkeltype_transpose(::Type{ResultEltype}, ::Type{ParentEltype}) where {ResultEltype,ParentEltype}
35+
ResultEltype === Base.promote_op(transpose, ParentEltype) || error(string(
36+
"Element type mismatch. Tried to create a `Transpose{$ResultEltype}` ",
37+
"from an object with eltype `$ParentEltype`, but the element type of ",
38+
"the transpose of an object with eltype `$ParentEltype` must be ",
39+
"`$(Base.promote_op(transpose, ParentEltype))`."))
40+
return nothing
4141
end
4242

4343
# basic outer constructors
44-
Adjoint(A) = Adjoint{transformtype(Adjoint,eltype(A)),typeof(A)}(A)
45-
Transpose(A) = Transpose{transformtype(Transpose,eltype(A)),typeof(A)}(A)
46-
47-
# numbers are the end of the line
48-
Adjoint(x::Number) = adjoint(x)
49-
Transpose(x::Number) = transpose(x)
44+
Adjoint(A) = Adjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)
45+
Transpose(A) = Transpose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)
5046

51-
# unwrapping constructors
52-
Adjoint(A::Adjoint) = A.parent
53-
Transpose(A::Transpose) = A.parent
47+
# no-op constructors for already-wrapped objects
48+
Adjoint(A::Adjoint) = A
49+
Transpose(A::Transpose) = A
5450

5551
# wrapping lowercase quasi-constructors
5652
adjoint(A::AbstractVecOrMat) = Adjoint(A)
@@ -80,6 +76,7 @@ julia> transpose(A)
8076
```
8177
"""
8278
transpose(A::AbstractVecOrMat) = Transpose(A)
79+
8380
# unwrapping lowercase quasi-constructors
8481
adjoint(A::Adjoint) = A.parent
8582
transpose(A::Transpose) = A.parent
@@ -95,10 +92,8 @@ const AdjOrTransAbsVec{T} = AdjOrTrans{T,<:AbstractVector}
9592
const AdjOrTransAbsMat{T} = AdjOrTrans{T,<:AbstractMatrix}
9693

9794
# for internal use below
98-
wrappertype(A::Adjoint) = Adjoint
99-
wrappertype(A::Transpose) = Transpose
100-
wrappertype(::Type{<:Adjoint}) = Adjoint
101-
wrappertype(::Type{<:Transpose}) = Transpose
95+
wrapperop(A::Adjoint) = adjoint
96+
wrapperop(A::Transpose) = transpose
10297

10398
# AbstractArray interface, basic definitions
10499
length(A::AdjOrTrans) = length(A.parent)
@@ -108,22 +103,22 @@ axes(v::AdjOrTransAbsVec) = (Base.OneTo(1), axes(v.parent)...)
108103
axes(A::AdjOrTransAbsMat) = reverse(axes(A.parent))
109104
IndexStyle(::Type{<:AdjOrTransAbsVec}) = IndexLinear()
110105
IndexStyle(::Type{<:AdjOrTransAbsMat}) = IndexCartesian()
111-
@propagate_inbounds getindex(v::AdjOrTransAbsVec, i::Int) = wrappertype(v)(v.parent[i])
112-
@propagate_inbounds getindex(A::AdjOrTransAbsMat, i::Int, j::Int) = wrappertype(A)(A.parent[j, i])
113-
@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, wrappertype(v)(x), i); v)
114-
@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, wrappertype(A)(x), j, i); A)
106+
@propagate_inbounds getindex(v::AdjOrTransAbsVec, i::Int) = wrapperop(v)(v.parent[i])
107+
@propagate_inbounds getindex(A::AdjOrTransAbsMat, i::Int, j::Int) = wrapperop(A)(A.parent[j, i])
108+
@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, wrapperop(v)(x), i); v)
109+
@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, wrapperop(A)(x), j, i); A)
115110
# AbstractArray interface, additional definitions to retain wrapper over vectors where appropriate
116-
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, is::AbstractArray{Int}) = wrappertype(v)(v.parent[is])
117-
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, ::Colon) = wrappertype(v)(v.parent[:])
111+
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, is::AbstractArray{Int}) = wrapperop(v)(v.parent[is])
112+
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, ::Colon) = wrapperop(v)(v.parent[:])
118113

119114
# conversion of underlying storage
120115
convert(::Type{Adjoint{T,S}}, A::Adjoint) where {T,S} = Adjoint{T,S}(convert(S, A.parent))
121116
convert(::Type{Transpose{T,S}}, A::Transpose) where {T,S} = Transpose{T,S}(convert(S, A.parent))
122117

123118
# for vectors, the semantics of the wrapped and unwrapped types differ
124119
# so attempt to maintain both the parent and wrapper type insofar as possible
125-
similar(A::AdjOrTransAbsVec) = wrappertype(A)(similar(A.parent))
126-
similar(A::AdjOrTransAbsVec, ::Type{T}) where {T} = wrappertype(A)(similar(A.parent, transformtype(wrappertype(A), T)))
120+
similar(A::AdjOrTransAbsVec) = wrapperop(A)(similar(A.parent))
121+
similar(A::AdjOrTransAbsVec, ::Type{T}) where {T} = wrapperop(A)(similar(A.parent, Base.promote_op(wrapperop(A), T)))
127122
# for matrices, the semantics of the wrapped and unwrapped types are generally the same
128123
# and as you are allocating with similar anyway, you might as well get something unwrapped
129124
similar(A::AdjOrTrans) = similar(A.parent, eltype(A), size(A))
@@ -142,30 +137,31 @@ isless(A::AdjOrTransAbsVec, B::AdjOrTransAbsVec) = isless(parent(A), parent(B))
142137
# to retain the associated semantics post-concatenation
143138
hcat(avs::Union{Number,AdjointAbsVec}...) = _adjoint_hcat(avs...)
144139
hcat(tvs::Union{Number,TransposeAbsVec}...) = _transpose_hcat(tvs...)
145-
_adjoint_hcat(avs::Union{Number,AdjointAbsVec}...) = Adjoint(vcat(map(Adjoint, avs)...))
146-
_transpose_hcat(tvs::Union{Number,TransposeAbsVec}...) = Transpose(vcat(map(Transpose, tvs)...))
147-
typed_hcat(::Type{T}, avs::Union{Number,AdjointAbsVec}...) where {T} = Adjoint(typed_vcat(T, map(Adjoint, avs)...))
148-
typed_hcat(::Type{T}, tvs::Union{Number,TransposeAbsVec}...) where {T} = Transpose(typed_vcat(T, map(Transpose, tvs)...))
140+
_adjoint_hcat(avs::Union{Number,AdjointAbsVec}...) = adjoint(vcat(map(adjoint, avs)...))
141+
_transpose_hcat(tvs::Union{Number,TransposeAbsVec}...) = transpose(vcat(map(transpose, tvs)...))
142+
typed_hcat(::Type{T}, avs::Union{Number,AdjointAbsVec}...) where {T} = adjoint(typed_vcat(T, map(adjoint, avs)...))
143+
typed_hcat(::Type{T}, tvs::Union{Number,TransposeAbsVec}...) where {T} = transpose(typed_vcat(T, map(transpose, tvs)...))
149144
# otherwise-redundant definitions necessary to prevent hitting the concat methods in sparse/sparsevector.jl
150145
hcat(avs::Adjoint{<:Any,<:Vector}...) = _adjoint_hcat(avs...)
151146
hcat(tvs::Transpose{<:Any,<:Vector}...) = _transpose_hcat(tvs...)
152147
hcat(avs::Adjoint{T,Vector{T}}...) where {T} = _adjoint_hcat(avs...)
153148
hcat(tvs::Transpose{T,Vector{T}}...) where {T} = _transpose_hcat(tvs...)
149+
# TODO unify and allow mixed combinations
154150

155151

156152
### higher order functions
157153
# preserve Adjoint/Transpose wrapper around vectors
158154
# to retain the associated semantics post-map/broadcast
159155
#
160156
# note that the caller's operation f operates in the domain of the wrapped vectors' entries.
161-
# hence the Adjoint->f->Adjoint shenanigans applied to the parent vectors' entries.
162-
map(f, avs::AdjointAbsVec...) = Adjoint(map((xs...) -> Adjoint(f(Adjoint.(xs)...)), parent.(avs)...))
163-
map(f, tvs::TransposeAbsVec...) = Transpose(map((xs...) -> Transpose(f(Transpose.(xs)...)), parent.(tvs)...))
157+
# hence the adjoint->f->adjoint shenanigans applied to the parent vectors' entries.
158+
map(f, avs::AdjointAbsVec...) = adjoint(map((xs...) -> adjoint(f(adjoint.(xs)...)), parent.(avs)...))
159+
map(f, tvs::TransposeAbsVec...) = transpose(map((xs...) -> transpose(f(transpose.(xs)...)), parent.(tvs)...))
164160
quasiparentt(x) = parent(x); quasiparentt(x::Number) = x # to handle numbers in the defs below
165161
quasiparenta(x) = parent(x); quasiparenta(x::Number) = conj(x) # to handle numbers in the defs below
166-
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = Adjoint(broadcast((xs...) -> Adjoint(f(Adjoint.(xs)...)), quasiparenta.(avs)...))
167-
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = Transpose(broadcast((xs...) -> Transpose(f(Transpose.(xs)...)), quasiparentt.(tvs)...))
168-
162+
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...))
163+
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
164+
# TODO unify and allow mixed combinations
169165

170166
### linear algebra
171167

@@ -186,11 +182,11 @@ end
186182
*(u::TransposeAbsVec, v::TransposeAbsVec) = throw(MethodError(*, (u, v)))
187183

188184
# Adjoint/Transpose-vector * matrix
189-
*(u::AdjointAbsVec, A::AbstractMatrix) = Adjoint(Adjoint(A) * u.parent)
190-
*(u::TransposeAbsVec, A::AbstractMatrix) = Transpose(Transpose(A) * u.parent)
185+
*(u::AdjointAbsVec, A::AbstractMatrix) = adjoint(adjoint(A) * u.parent)
186+
*(u::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A) * u.parent)
191187
# Adjoint/Transpose-vector * Adjoint/Transpose-matrix
192-
*(u::AdjointAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = Adjoint(A.parent * u.parent)
193-
*(u::TransposeAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = Transpose(A.parent * u.parent)
188+
*(u::AdjointAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = adjoint(A.parent * u.parent)
189+
*(u::TransposeAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = transpose(A.parent * u.parent)
194190

195191

196192
## pseudoinversion
@@ -203,10 +199,10 @@ pinv(v::TransposeAbsVec, tol::Real = 0) = pinv(conj(v.parent)).parent
203199

204200

205201
## right-division \
206-
/(u::AdjointAbsVec, A::AbstractMatrix) = Adjoint(Adjoint(A) \ u.parent)
207-
/(u::TransposeAbsVec, A::AbstractMatrix) = Transpose(Transpose(A) \ u.parent)
208-
/(u::AdjointAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = Adjoint(conj(A.parent) \ u.parent) # technically should be Adjoint(copy(Adjoint(copy(A))) \ u.parent)
209-
/(u::TransposeAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = Transpose(conj(A.parent) \ u.parent) # technically should be Transpose(copy(Transpose(copy(A))) \ u.parent)
202+
/(u::AdjointAbsVec, A::AbstractMatrix) = adjoint(adjoint(A) \ u.parent)
203+
/(u::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A) \ u.parent)
204+
/(u::AdjointAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = adjoint(conj(A.parent) \ u.parent) # technically should be adjoint(copy(adjoint(copy(A))) \ u.parent)
205+
/(u::TransposeAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = transpose(conj(A.parent) \ u.parent) # technically should be transpose(copy(transpose(copy(A))) \ u.parent)
210206

211207
# dismabiguation methods
212208
*(A::AdjointAbsVec, B::Transpose{<:Any,<:AbstractMatrix}) = A * copy(B)

‎test/linalg/adjtrans.jl

+16-16
Original file line numberDiff line numberDiff line change
@@ -62,23 +62,12 @@ end
6262
# the tests for the inner constructors exercise abstract scalar and concrete array eltype, forgoing here
6363
end
6464

65-
@testset "Adjoint and Transpose of Numbers" begin
66-
@test Adjoint(1) == 1
67-
@test Adjoint(1.0) == 1.0
68-
@test Adjoint(1im) == -1im
69-
@test Adjoint(1.0im) == -1.0im
70-
@test Transpose(1) == 1
71-
@test Transpose(1.0) == 1.0
72-
@test Transpose(1im) == 1im
73-
@test Transpose(1.0im) == 1.0im
74-
end
75-
76-
@testset "Adjoint and Transpose unwrapping" begin
65+
@testset "Adjoint and Transpose no-op on already-wrapped objects" begin
7766
intvec, intmat = [1, 2], [1 2; 3 4]
78-
@test Adjoint(Adjoint(intvec)) === intvec
79-
@test Adjoint(Adjoint(intmat)) === intmat
80-
@test Transpose(Transpose(intvec)) === intvec
81-
@test Transpose(Transpose(intmat)) === intmat
67+
@test (A = Adjoint(intvec); Adjoint(A) === A)
68+
@test (A = Adjoint(intmat); Adjoint(A) === A)
69+
@test (A = Transpose(intvec); Transpose(A) === A)
70+
@test (A = Transpose(intmat); Transpose(A) === A)
8271
end
8372

8473
@testset "Adjoint and Transpose basic AbstractArray functionality" begin
@@ -441,6 +430,17 @@ end
441430
end
442431
end
443432

433+
@testset "adjoint and transpose of Numbers" begin
434+
@test adjoint(1) == 1
435+
@test adjoint(1.0) == 1.0
436+
@test adjoint(1im) == -1im
437+
@test adjoint(1.0im) == -1.0im
438+
@test transpose(1) == 1
439+
@test transpose(1.0) == 1.0
440+
@test transpose(1im) == 1im
441+
@test transpose(1.0im) == 1.0im
442+
end
443+
444444
@testset "adjoint!(a, b) return a" begin
445445
a = fill(1.0+im, 5)
446446
b = fill(1.0+im, 1, 5)

0 commit comments

Comments
 (0)
Please sign in to comment.