Skip to content

Commit 4db5736

Browse files
committed
Porting SDiagonal from Bridge.jl (JuliaArrays#240)
* Porting SDiagonal from Bridge.jl Originally contributed by D. Getz (https://github.com/getzdan), M. Schauer at https://github.com/mschauer/Bridge.jl under MIT License * Inherit from AbstractMatrix * Fixes and tests for constructors * More tests, more fixes * Address andyferris' comments * SDiagonal: next round of comments * Remaining tests for SDiagonal * Where-notation in SDiagonal * Test for inv(::SDiagonal)
1 parent 3865c6f commit 4db5736

File tree

4 files changed

+216
-2
lines changed

4 files changed

+216
-2
lines changed

src/SDiagonal.jl

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Originally contributed by D. Getz (https://github.com/getzdan), M. Schauer
2+
# at https://github.com/mschauer/Bridge.jl under MIT License
3+
4+
import Base: ==, -, +, *, /, \, abs, real, imag, conj
5+
6+
@generated function scalem(a::StaticMatrix{M,N}, b::StaticVector{N}) where {M, N}
7+
expr = vec([:(a[$j,$i]*b[$i]) for j=1:M, i=1:N])
8+
:(@_inline_meta; let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end)
9+
end
10+
@generated function scalem(a::StaticVector{M}, b::StaticMatrix{M, N}) where {M, N}
11+
expr = vec([:(b[$j,$i]*a[$j]) for j=1:M, i=1:N])
12+
:(@_inline_meta; let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end)
13+
end
14+
15+
struct SDiagonal{N,T} <: StaticMatrix{N,N,T}
16+
diag::SVector{N,T}
17+
SDiagonal{N,T}(diag::SVector{N,T}) where {N,T} = new(diag)
18+
end
19+
diagtype(::Type{SDiagonal{N,T}}) where {N, T} = SVector{N,T}
20+
diagtype(::Type{SDiagonal{N}}) where {N} = SVector{N}
21+
diagtype(::Type{SDiagonal}) = SVector
22+
23+
# this is to deal with convert.jl
24+
@inline (::Type{SD})(a::AbstractVector) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a))
25+
@inline (::Type{SD})(a::Tuple) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a))
26+
@inline (::Type{SDiagonal})(a::SVector{N,T}) where {N,T} = SDiagonal{N,T}(a)
27+
28+
@generated function SDiagonal(a::StaticMatrix{N,N,T}) where {N,T}
29+
expr = [:(a[$i,$i]) for i=1:N]
30+
:(SDiagonal{N,T}($(expr...)))
31+
end
32+
33+
convert(::Type{SDiagonal{N,T}}, D::SDiagonal{N,T}) where {N,T} = D
34+
convert(::Type{SDiagonal{N,T}}, D::SDiagonal{N}) where {N,T} = SDiagonal{N,T}(convert(SVector{N,T}, D.diag))
35+
36+
function getindex(D::SDiagonal{N,T}, i::Int, j::Int) where {N,T}
37+
@boundscheck checkbounds(D, i, j)
38+
@inbounds return ifelse(i == j, D.diag[i], zero(T))
39+
end
40+
41+
# avoid linear indexing?
42+
@propagate_inbounds function getindex(D::SDiagonal{N,T}, k::Int) where {N,T}
43+
i, j = ind2sub(size(D), k)
44+
D[i,j]
45+
end
46+
47+
ishermitian(D::SDiagonal{N, T}) where {N,T<:Real} = true
48+
ishermitian(D::SDiagonal) = all(D.diag .== real(D.diag))
49+
issymmetric(D::SDiagonal) = true
50+
isposdef(D::SDiagonal) = all(D.diag .> 0)
51+
52+
factorize(D::SDiagonal) = D
53+
54+
==(Da::SDiagonal, Db::SDiagonal) = Da.diag == Db.diag
55+
-(A::SDiagonal) = SDiagonal(-A.diag)
56+
+(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag + Db.diag)
57+
-(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag - Db.diag)
58+
-(A::SDiagonal, B::SMatrix) = eye(typeof(B))*A - B
59+
60+
*(x::T, D::SDiagonal) where {T<:Number} = SDiagonal(x * D.diag)
61+
*(D::SDiagonal, x::T) where {T<:Number} = SDiagonal(D.diag * x)
62+
/(D::SDiagonal, x::T) where {T<:Number} = SDiagonal(D.diag / x)
63+
*(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag .* Db.diag)
64+
*(D::SDiagonal, V::AbstractVector) = D.diag .* V
65+
*(D::SDiagonal, V::StaticVector) = D.diag .* V
66+
*(A::StaticMatrix, D::SDiagonal) = scalem(A,D.diag)
67+
*(D::SDiagonal, A::StaticMatrix) = scalem(D.diag,A)
68+
\(D::SDiagonal, b::AbstractVector) = D.diag .\ b
69+
\(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity
70+
71+
conj(D::SDiagonal) = SDiagonal(conj(D.diag))
72+
transpose(D::SDiagonal) = D
73+
ctranspose(D::SDiagonal) = conj(D)
74+
75+
diag(D::SDiagonal) = D.diag
76+
trace(D::SDiagonal) = sum(D.diag)
77+
det(D::SDiagonal) = prod(D.diag)
78+
logdet{N,T<:Real}(D::SDiagonal{N,T}) = sum(log.(D.diag))
79+
function logdet(D::SDiagonal{N,T}) where {N,T<:Complex} #Make sure branch cut is correct
80+
x = sum(log.(D.diag))
81+
-pi<imag(x)<pi ? x : real(x)+(mod2pi(imag(x)+pi)-pi)*im
82+
end
83+
84+
eye(::Type{SDiagonal{N,T}}) where {N,T} = SDiagonal(ones(SVector{N,T}))
85+
86+
expm(D::SDiagonal) = SDiagonal(exp.(D.diag))
87+
logm(D::SDiagonal) = SDiagonal(log.(D.diag))
88+
sqrtm(D::SDiagonal) = SDiagonal(sqrt.(D.diag))
89+
90+
\(D::SDiagonal, B::StaticMatrix) = scalem(1 ./ D.diag, B)
91+
/(B::StaticMatrix, D::SDiagonal) = scalem(1 ./ D.diag, B)
92+
\(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Db.diag ./ Da.diag)
93+
/(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag ./ Db.diag )
94+
95+
96+
@generated function check_singular(D::SDiagonal{N}) where {N}
97+
quote
98+
Base.Cartesian.@nexprs $N i->(@inbounds iszero(D.diag[i]) && throw(Base.LinAlg.SingularException(i)))
99+
end
100+
end
101+
102+
function inv(D::SDiagonal)
103+
check_singular(D)
104+
SDiagonal(inv.(D.diag))
105+
end
106+

src/StaticArrays.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import Base: getindex, setindex!, size, similar, vec, show,
88
length, convert, promote_op, promote_rule, map, map!, reduce, reducedim, mapreducedim,
99
mapreduce, broadcast, broadcast!, conj, transpose, ctranspose,
1010
hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill,
11-
fill!, det, inv, eig, eigvals, expm, sqrtm, trace, vecnorm, norm, dot, diagm, diag,
12-
lu, svd, svdvals, svdfact,
11+
fill!, det, logdet, inv, eig, eigvals, expm, logm, sqrtm, trace, diag, vecnorm, norm, dot, diagm, diag,
12+
lu, svd, svdvals, svdfact, factorize, ishermitian, issymmetric, isposdef,
1313
sum, diff, prod, count, any, all, minimum,
1414
maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!,
1515
randexp!, normalize, normalize!, read, read!, write
@@ -19,6 +19,7 @@ export Scalar, SArray, SVector, SMatrix
1919
export MArray, MVector, MMatrix
2020
export FieldVector
2121
export SizedArray, SizedVector, SizedMatrix
22+
export SDiagonal
2223

2324
export Size, Length
2425

@@ -79,6 +80,7 @@ include("MArray.jl")
7980
include("MVector.jl")
8081
include("MMatrix.jl")
8182
include("SizedArray.jl")
83+
include("SDiagonal.jl")
8284

8385
include("abstractarray.jl")
8486
include("indexing.jl")

test/SDiagonal.jl

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
@testset "SDiagonal" begin
2+
@testset "Constructors" begin
3+
@test SDiagonal{1,Int64}((1,)).diag === SVector{1,Int64}((1,))
4+
@test SDiagonal{1,Float64}((1,)).diag === SVector{1,Float64}((1,))
5+
6+
@test SDiagonal{4,Float64}((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0)
7+
@test SDiagonal{4}((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0)
8+
@test SDiagonal((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0)
9+
10+
# Bad input
11+
@test_throws Exception SMatrix{1,Int}()
12+
@test_throws Exception SMatrix{2,Int}((1,))
13+
14+
# From SMatrix
15+
@test SDiagonal(SMatrix{2,2,Int}((1,2,3,4))).diag.data === (1,4)
16+
17+
@test SDiagonal{1,Int}(SDiagonal{1,Float64}((1,))).diag[1] === 1
18+
19+
end
20+
21+
@testset "Methods" begin
22+
23+
@test StaticArrays.scalem(@SMatrix([1 1 1;1 1 1; 1 1 1]), @SVector [1,2,3]) === @SArray [1 2 3; 1 2 3; 1 2 3]
24+
@test StaticArrays.scalem(@SVector([1,2,3]),@SMatrix [1 1 1;1 1 1; 1 1 1])' === @SArray [1 2 3; 1 2 3; 1 2 3]
25+
26+
m = SDiagonal(@SVector [11, 12, 13, 14])
27+
28+
29+
30+
@test diag(m) === m.diag
31+
32+
33+
m2 = diagm([11, 12, 13, 14])
34+
35+
@test logdet(m) == logdet(m2)
36+
@test logdet(im*m) logdet(im*m2)
37+
@test det(m) == det(m2)
38+
@test trace(m) == trace(m2)
39+
@test logm(m) == logm(m2)
40+
@test expm(m) == expm(m2)
41+
@test sqrtm(m) == sqrtm(m2)
42+
43+
44+
@test isimmutable(m) == true
45+
46+
@test m[1,1] === 11
47+
@test m[2,2] === 12
48+
@test m[3,3] === 13
49+
@test m[4,4] === 14
50+
51+
for i in 1:4
52+
for j in 1:4
53+
i == j || @test m[i,j] === 0
54+
end
55+
end
56+
57+
@test_throws Exception m[5,5]
58+
59+
@test_throws Exception m[1,5]
60+
61+
62+
@test size(m) === (4, 4)
63+
@test size(typeof(m)) === (4, 4)
64+
@test size(SDiagonal{4}) === (4, 4)
65+
66+
@test size(m, 1) === 4
67+
@test size(m, 2) === 4
68+
@test size(typeof(m), 1) === 4
69+
@test size(typeof(m), 2) === 4
70+
71+
@test length(m) === 4*4
72+
73+
@test_throws Exception m[1] = 1
74+
75+
b = @SVector [2,-1,2,1]
76+
b2 = Vector(b)
77+
78+
79+
@test m*b == @SVector [22,-12,26,14]
80+
@test (b'*m)' == @SVector [22,-12,26,14]
81+
82+
@test m\b == m2\b
83+
84+
@test b'/m == b'/m2
85+
@test_throws Exception b/m
86+
@test m*m == m2*m
87+
88+
@test ishermitian(m) == ishermitian(m2)
89+
@test ishermitian(m/2)
90+
91+
@test isposdef(m) == isposdef(m2)
92+
@test issymmetric(m) == issymmetric(m2)
93+
94+
@test (2*m/2)' == m
95+
@test 2m == m + m
96+
@test m*0 == m - m
97+
98+
@test m*inv(m) == m/m == m\m == eye(SDiagonal{4,Float64})
99+
100+
101+
102+
103+
end
104+
end

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ include("FieldVector.jl")
1212
include("Scalar.jl")
1313
include("SUnitRange.jl")
1414
include("SizedArray.jl")
15+
include("SDiagonal.jl")
16+
1517
include("custom_types.jl")
1618
include("core.jl")
1719
include("abstractarray.jl")

0 commit comments

Comments
 (0)