Skip to content

Commit 4f3081d

Browse files
committed
Porting SDiagonal from Bridge.jl
Originally contributed by D. Getz (https://github.com/getzdan), M. Schauer at https://github.com/mschauer/Bridge.jl under MIT License
1 parent 80b6bac commit 4f3081d

File tree

2 files changed

+119
-0
lines changed

2 files changed

+119
-0
lines changed

src/SDiagonal.jl

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Originally contributed by D. Getz (https://github.com/getzdan), M. Schauer
2+
# at https://github.com/mschauer/Bridge.jl under MIT License
3+
4+
import Base: getindex,setindex!,==,-,+,*,/,\,transpose,ctranspose,convert, size, abs, real, imag, conj, eye, inv
5+
import Base.LinAlg: ishermitian, issymmetric, isposdef, factorize, diag, trace, det, logdet, expm, logm, sqrtm
6+
7+
@generated function scalem{T, M, N}(a::SMatrix{M,N, T}, b::SVector{N, T})
8+
expr = vec([:(a[$j,$i]*b[$i]) for j=1:M, i=1:N])
9+
:(SMatrix{M,N,T}($(expr...)))
10+
end
11+
@generated function scalem{T, M, N}(a::SVector{M,T}, b::SMatrix{M, N, T})
12+
expr = vec([:(b[$j,$i]*a[$j]) for j=1:M, i=1:N])
13+
:(SMatrix{M,N,T}($(expr...)))
14+
end
15+
16+
if ! @_isdefined(SDiagonal)
17+
struct SDiagonal{N,T}
18+
diag::SVector{N,T}
19+
end
20+
end
21+
22+
function \{T,M}(D::SDiagonal, b::SVector{M,T} )
23+
D.diag .* b
24+
end
25+
26+
SDiagonal(A::SMatrix) = SDiagonal(diag(A))
27+
28+
29+
convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal{N,T}) = D
30+
convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal) = SDiagonal{N,T}(convert(SVector{N,T}, D.diag))
31+
32+
size(D::SDiagonal) = (length(D.diag),length(D.diag))
33+
34+
function size(D::SDiagonal,d::Integer)
35+
if d<1
36+
throw(ArgumentError("dimension must be ≥ 1, got $d"))
37+
end
38+
return d<=2 ? length(D.diag) : 1
39+
end
40+
41+
function getindex{T}(D::SDiagonal{T}, i::Int, j::Int)
42+
if i == j
43+
D.diag[i]
44+
else
45+
zero(T)
46+
end
47+
end
48+
function setindex!(D::SDiagonal, v, i::Int, j::Int)
49+
if i == j
50+
unsafe_setindex!(D.diag, v, i)
51+
elseif v != 0
52+
throw(ArgumentError("cannot set an off-diagonal index ($i, $j) to a nonzero value ($v)"))
53+
end
54+
D
55+
end
56+
57+
ishermitian{T<:Real}(D::SDiagonal{T}) = true
58+
ishermitian(D::SDiagonal) = all(D.diag .== real(D.diag))
59+
issym(D::SDiagonal) = true
60+
isposdef(D::SDiagonal) = all(D.diag .> 0)
61+
62+
factorize(D::SDiagonal) = D
63+
64+
abs(D::SDiagonal) = SDiagonal(abs(D.diag))
65+
real(D::SDiagonal) = SDiagonal(real(D.diag))
66+
imag(D::SDiagonal) = SDiagonal(imag(D.diag))
67+
68+
==(Da::SDiagonal, Db::SDiagonal) = Da.diag == Db.diag
69+
-(A::SDiagonal) = SDiagonal(-A.diag)
70+
+(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag + Db.diag)
71+
-(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag - Db.diag)
72+
-(A::SDiagonal, B::SMatrix) = eye(typeof(B))*A - B
73+
74+
75+
*{T<:Number}(x::T, D::SDiagonal) = SDiagonal(x * D.diag)
76+
*{T<:Number}(D::SDiagonal, x::T) = SDiagonal(D.diag * x)
77+
/{T<:Number}(D::SDiagonal, x::T) = SDiagonal(D.diag / x)
78+
*(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag .* Db.diag)
79+
*(D::SDiagonal, V::SVector) = D.diag .* V
80+
*(V::SVector, D::SDiagonal) = D.diag .* V
81+
*(A::SMatrix, D::SDiagonal) = scalem(A,D.diag)
82+
*(D::SDiagonal, A::SMatrix) = scalem(D.diag,A)
83+
84+
/(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag ./ Db.diag )
85+
86+
conj(D::SDiagonal) = SDiagonal(conj(D.diag))
87+
transpose(D::SDiagonal) = D
88+
ctranspose(D::SDiagonal) = conj(D)
89+
90+
diag(D::SDiagonal) = D.diag
91+
trace(D::SDiagonal) = sum(D.diag)
92+
det(D::SDiagonal) = prod(D.diag)
93+
logdet{N,T<:Real}(D::SDiagonal{N,T}) = sum(log.(D.diag))
94+
function logdet{N,T<:Complex}(D::SDiagonal{N,T}) #Make sure branch cut is correct
95+
x = sum(log.(D.diag))
96+
-pi<imag(x)<pi ? x : real(x)+(mod2pi(imag(x)+pi)-pi)*im
97+
end
98+
99+
100+
eye{N,T}(::Type{SDiagonal{N,T}}) = SDiagonal(one(SVector{n,Int}))
101+
102+
expm(D::SDiagonal) = SDiagonal(exp.(D.diag))
103+
logm(D::SDiagonal) = SDiagonal(log.(D.diag))
104+
sqrtm(D::SDiagonal) = SDiagonal(sqrt.(D.diag))
105+
106+
\(D::SDiagonal, B::SMatrix) = scalem(1 ./ D.diag, B)
107+
/(B::SMatrix, D::SDiagonal) = scalem(1 ./ D.diag, B)
108+
\(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Db.diag ./ Da.diag)
109+
110+
function inv{N,T}(D::SDiagonal{N,T})
111+
for i = 1:length(D.diag)
112+
if D.diag[i] == zero(T)
113+
throw(SingularException(i))
114+
end
115+
end
116+
SDiagonal(one(T)./D.diag)
117+
end
118+

src/StaticArrays.jl

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ include("MArray.jl")
7979
include("MVector.jl")
8080
include("MMatrix.jl")
8181
include("SizedArray.jl")
82+
include("SDiagonal.jl")
8283

8384
include("abstractarray.jl")
8485
include("indexing.jl")

0 commit comments

Comments
 (0)