Skip to content

Commit 601e535

Browse files
mcabbottfredrikekremartinholters
committed
Add dot(x,A,y) (#683)
* add some dot(x,A,y) methods * Bump version to 3.2.0 Co-authored-by: Fredrik Ekre <[email protected]> Co-authored-by: Martin Holters <[email protected]>
1 parent 310219d commit 601e535

File tree

4 files changed

+183
-1
lines changed

4 files changed

+183
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Compat"
22
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
3-
version = "3.1.0"
3+
version = "3.2.0"
44

55
[deps]
66
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ Please check the list below for the specific syntax you need.
4444

4545
## Supported features
4646

47+
* `dot` now has a 3-argument method `dot(x, A, y)` without storing the intermediate result `A*y` ([#32739]). (since Compat 3.2.0)
48+
4749
* `pkgdir(m)` returns the root directory of the package that imported module `m` ([#33128]). (since Compat 3.2.0)
4850

4951
* `filter` can now act on a `Tuple` [#32968]. (since Compat 3.1.0)
@@ -104,6 +106,7 @@ Note that you should specify the correct minimum version for `Compat` in the
104106
[#29674]: https://github.com/JuliaLang/julia/issues/29674
105107
[#29749]: https://github.com/JuliaLang/julia/issues/29749
106108
[#32628]: https://github.com/JuliaLang/julia/issues/32628
109+
[#32739]: https://github.com/JuliaLang/julia/pull/32739
107110
[#33129]: https://github.com/JuliaLang/julia/issues/33129
108111
[#33568]: https://github.com/JuliaLang/julia/pull/33568
109112
[#33128]: https://github.com/JuliaLang/julia/pull/33128

src/Compat.jl

+84
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module Compat
22

3+
import LinearAlgebra
4+
using LinearAlgebra: Adjoint, Diagonal, Transpose, UniformScaling, RealHermSymComplexHerm
5+
36
include("compatmacro.jl")
47

58
# https://github.com/JuliaLang/julia/pull/29679
@@ -88,6 +91,87 @@ if VERSION < v"1.3.0-alpha.8"
8891
Base.mod(i::Integer, r::AbstractUnitRange{<:Integer}) = mod(i-first(r), length(r)) + first(r)
8992
end
9093

94+
# https://github.com/JuliaLang/julia/pull/32739
95+
# This omits special methods for more exotic matrix types, Triangular and worse.
96+
if VERSION < v"1.4.0-DEV.92" # 2425ae760fb5151c5c7dd0554e87c5fc9e24de73
97+
98+
# stdlib/LinearAlgebra/src/generic.jl
99+
LinearAlgebra.dot(x, A, y) = LinearAlgebra.dot(x, A*y) # generic fallback
100+
101+
function LinearAlgebra.dot(x::AbstractVector, A::AbstractMatrix, y::AbstractVector)
102+
(axes(x)..., axes(y)...) == axes(A) || throw(DimensionMismatch())
103+
T = typeof(LinearAlgebra.dot(first(x), first(A), first(y)))
104+
s = zero(T)
105+
i₁ = first(eachindex(x))
106+
x₁ = first(x)
107+
@inbounds for j in eachindex(y)
108+
yj = y[j]
109+
if !iszero(yj)
110+
temp = zero(adjoint(A[i₁,j]) * x₁)
111+
@simd for i in eachindex(x)
112+
temp += adjoint(A[i,j]) * x[i]
113+
end
114+
s += LinearAlgebra.dot(temp, yj)
115+
end
116+
end
117+
return s
118+
end
119+
LinearAlgebra.dot(x::AbstractVector, adjA::Adjoint, y::AbstractVector) =
120+
adjoint(LinearAlgebra.dot(y, adjA.parent, x))
121+
LinearAlgebra.dot(x::AbstractVector, transA::Transpose{<:Real}, y::AbstractVector) =
122+
adjoint(LinearAlgebra.dot(y, transA.parent, x))
123+
124+
# stdlib/LinearAlgebra/src/diagonal.jl
125+
function LinearAlgebra.dot(x::AbstractVector, D::Diagonal, y::AbstractVector)
126+
mapreduce(t -> LinearAlgebra.dot(t[1], t[2], t[3]), +, zip(x, D.diag, y))
127+
end
128+
129+
# stdlib/LinearAlgebra/src/symmetric.jl
130+
function LinearAlgebra.dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector)
131+
require_one_based_indexing(x, y)
132+
(length(x) == length(y) == size(A, 1)) || throw(DimensionMismatch())
133+
data = A.data
134+
r = zero(eltype(x)) * zero(eltype(A)) * zero(eltype(y))
135+
if A.uplo == 'U'
136+
@inbounds for j = 1:length(y)
137+
r += LinearAlgebra.dot(x[j], real(data[j,j]), y[j])
138+
@simd for i = 1:j-1
139+
Aij = data[i,j]
140+
r += LinearAlgebra.dot(x[i], Aij, y[j]) +
141+
LinearAlgebra.dot(x[j], adjoint(Aij), y[i])
142+
end
143+
end
144+
else # A.uplo == 'L'
145+
@inbounds for j = 1:length(y)
146+
r += LinearAlgebra.dot(x[j], real(data[j,j]), y[j])
147+
@simd for i = j+1:length(y)
148+
Aij = data[i,j]
149+
r += LinearAlgebra.dot(x[i], Aij, y[j]) +
150+
LinearAlgebra.dot(x[j], adjoint(Aij), y[i])
151+
end
152+
end
153+
end
154+
return r
155+
end
156+
157+
# stdlib/LinearAlgebra/src/uniformscaling.jl
158+
LinearAlgebra.dot(x::AbstractVector, J::UniformScaling, y::AbstractVector) =
159+
LinearAlgebra.dot(x, J.λ, y)
160+
LinearAlgebra.dot(x::AbstractVector, a::Number, y::AbstractVector) =
161+
sum(t -> LinearAlgebra.dot(t[1], a, t[2]), zip(x, y))
162+
LinearAlgebra.dot(x::AbstractVector, a::Union{Real,Complex}, y::AbstractVector) =
163+
a*LinearAlgebra.dot(x, y)
164+
end
165+
166+
# https://github.com/JuliaLang/julia/pull/30630
167+
if VERSION < v"1.2.0-DEV.125" # 1da48c2e4028c1514ed45688be727efbef1db884
168+
require_one_based_indexing(A...) = !Base.has_offset_axes(A...) || throw(ArgumentError(
169+
"offset arrays are not supported but got an array with index other than 1"))
170+
# At present this is only used in Compat inside the above dot(x,A,y) functions, #32739
171+
elseif VERSION < v"1.4.0-DEV.92"
172+
using Base: require_one_based_indexing
173+
end
174+
91175
# https://github.com/JuliaLang/julia/pull/33568
92176
if VERSION < v"1.4.0-DEV.329"
93177
Base.:(f, g, h...) = (f g, h...)

test/runtests.jl

+95
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,101 @@ end
9090
@test_throws DivideError mod(3, 1:0)
9191
end
9292

93+
using LinearAlgebra
94+
95+
@testset "generalized dot #32739" begin
96+
# stdlib/LinearAlgebra/test/generic.jl
97+
for elty in (Int, Float32, Float64, BigFloat, Complex{Float32}, Complex{Float64}, Complex{BigFloat})
98+
n = 10
99+
if elty <: Int
100+
A = rand(-n:n, n, n)
101+
x = rand(-n:n, n)
102+
y = rand(-n:n, n)
103+
elseif elty <: Real
104+
A = convert(Matrix{elty}, randn(n,n))
105+
x = rand(elty, n)
106+
y = rand(elty, n)
107+
else
108+
A = convert(Matrix{elty}, complex.(randn(n,n), randn(n,n)))
109+
x = rand(elty, n)
110+
y = rand(elty, n)
111+
end
112+
@test dot(x, A, y) dot(A'x, y) *(x', A, y) (x'A)*y
113+
@test dot(x, A', y) dot(A*x, y) *(x', A', y) (x'A')*y
114+
elty <: Real && @test dot(x, transpose(A), y) dot(x, transpose(A)*y) *(x', transpose(A), y) (x'*transpose(A))*y
115+
B = reshape([A], 1, 1)
116+
x = [x]
117+
y = [y]
118+
@test dot(x, B, y) dot(B'x, y)
119+
@test dot(x, B', y) dot(B*x, y)
120+
elty <: Real && @test dot(x, transpose(B), y) dot(x, transpose(B)*y)
121+
end
122+
123+
# stdlib/LinearAlgebra/test/symmetric.jl
124+
n = 10
125+
areal = randn(n,n)/2
126+
aimg = randn(n,n)/2
127+
@testset for eltya in (Float32, Float64, ComplexF32, ComplexF64, BigFloat, Int)
128+
a = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(areal, aimg) : areal)
129+
asym = transpose(a) + a # symmetric indefinite
130+
aherm = a' + a # Hermitian indefinite
131+
apos = a' * a # Hermitian positive definite
132+
aposs = apos + transpose(apos) # Symmetric positive definite
133+
ε = εa = eps(abs(float(one(eltya))))
134+
x = randn(n)
135+
y = randn(n)
136+
b = randn(n,n)/2
137+
x = eltya == Int ? rand(1:7, n) : convert(Vector{eltya}, eltya <: Complex ? complex.(x, zeros(n)) : x)
138+
y = eltya == Int ? rand(1:7, n) : convert(Vector{eltya}, eltya <: Complex ? complex.(y, zeros(n)) : y)
139+
b = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(b, zeros(n,n)) : b)
140+
141+
@testset "generalized dot product" begin
142+
for uplo in (:U, :L)
143+
@test dot(x, Hermitian(aherm, uplo), y) dot(x, Hermitian(aherm, uplo)*y) dot(x, Matrix(Hermitian(aherm, uplo)), y)
144+
@test dot(x, Hermitian(aherm, uplo), x) dot(x, Hermitian(aherm, uplo)*x) dot(x, Matrix(Hermitian(aherm, uplo)), x)
145+
end
146+
if eltya <: Real
147+
for uplo in (:U, :L)
148+
@test dot(x, Symmetric(aherm, uplo), y) dot(x, Symmetric(aherm, uplo)*y) dot(x, Matrix(Symmetric(aherm, uplo)), y)
149+
@test dot(x, Symmetric(aherm, uplo), x) dot(x, Symmetric(aherm, uplo)*x) dot(x, Matrix(Symmetric(aherm, uplo)), x)
150+
end
151+
end
152+
end
153+
end
154+
155+
# stdlib/LinearAlgebra/test/uniformscaling.jl
156+
@testset "generalized dot" begin
157+
x = rand(-10:10, 3)
158+
y = rand(-10:10, 3)
159+
λ = rand(-10:10)
160+
J = UniformScaling(λ)
161+
@test dot(x, J, y) == λ*dot(x, y)
162+
end
163+
164+
# stdlib/LinearAlgebra/test/bidiag.jl
165+
# The special method for this is not in Compat #683, so this tests the generic fallback
166+
@testset "generalized dot" begin
167+
for elty in (Float64, ComplexF64)
168+
dv = randn(elty, 5)
169+
ev = randn(elty, 4)
170+
x = randn(elty, 5)
171+
y = randn(elty, 5)
172+
for uplo in (:U, :L)
173+
B = Bidiagonal(dv, ev, uplo)
174+
@test dot(x, B, y) dot(B'x, y) dot(x, Matrix(B), y)
175+
end
176+
end
177+
end
178+
179+
# Diagonal -- no such test in Base.
180+
@testset "diagonal" begin
181+
x = rand(-10:10, 3) .+ im
182+
y = rand(-10:10, 3) .+ im
183+
d = Diagonal(rand(-10:10, 3) .+ im)
184+
@test dot(x,d,y) == dot(x,collect(d),y) == dot(x, d*y)
185+
end
186+
end
187+
93188
# https://github.com/JuliaLang/julia/pull/33568
94189
@testset "function composition" begin
95190
@test (x -> x-2, x -> x-3, x -> x+5)(7) == 7

0 commit comments

Comments
 (0)