|
| 1 | +# Originally contributed by D. Getz (https://github.com/getzdan), M. Schauer |
| 2 | +# at https://github.com/mschauer/Bridge.jl under MIT License |
| 3 | + |
| 4 | +import Base: ==, -, +, *, /, \, abs, real, imag, conj |
| 5 | + |
| 6 | +@generated function scalem(a::StaticMatrix{M,N}, b::StaticVector{N}) where {M, N} |
| 7 | + expr = vec([:(a[$j,$i]*b[$i]) for j=1:M, i=1:N]) |
| 8 | + :(@_inline_meta; let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end) |
| 9 | +end |
| 10 | +@generated function scalem(a::StaticVector{M}, b::StaticMatrix{M, N}) where {M, N} |
| 11 | + expr = vec([:(b[$j,$i]*a[$j]) for j=1:M, i=1:N]) |
| 12 | + :(@_inline_meta; let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end) |
| 13 | +end |
| 14 | + |
| 15 | +struct SDiagonal{N,T} <: StaticMatrix{N,N,T} |
| 16 | + diag::SVector{N,T} |
| 17 | + SDiagonal{N,T}(diag::SVector{N,T}) where {N,T} = new(diag) |
| 18 | +end |
| 19 | +diagtype(::Type{SDiagonal{N,T}}) where {N, T} = SVector{N,T} |
| 20 | +diagtype(::Type{SDiagonal{N}}) where {N} = SVector{N} |
| 21 | +diagtype(::Type{SDiagonal}) = SVector |
| 22 | + |
| 23 | +# this is to deal with convert.jl |
| 24 | +@inline (::Type{SD})(a::AbstractVector) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a)) |
| 25 | +@inline (::Type{SD})(a::Tuple) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a)) |
| 26 | +@inline (::Type{SDiagonal})(a::SVector{N,T}) where {N,T} = SDiagonal{N,T}(a) |
| 27 | + |
| 28 | +@generated function SDiagonal(a::StaticMatrix{N,N,T}) where {N,T} |
| 29 | + expr = [:(a[$i,$i]) for i=1:N] |
| 30 | + :(SDiagonal{N,T}($(expr...))) |
| 31 | +end |
| 32 | + |
| 33 | +convert(::Type{SDiagonal{N,T}}, D::SDiagonal{N,T}) where {N,T} = D |
| 34 | +convert(::Type{SDiagonal{N,T}}, D::SDiagonal{N}) where {N,T} = SDiagonal{N,T}(convert(SVector{N,T}, D.diag)) |
| 35 | + |
| 36 | +function getindex(D::SDiagonal{N,T}, i::Int, j::Int) where {N,T} |
| 37 | + @boundscheck checkbounds(D, i, j) |
| 38 | + @inbounds return ifelse(i == j, D.diag[i], zero(T)) |
| 39 | +end |
| 40 | + |
| 41 | +# avoid linear indexing? |
| 42 | +@propagate_inbounds function getindex(D::SDiagonal{N,T}, k::Int) where {N,T} |
| 43 | + i, j = ind2sub(size(D), k) |
| 44 | + D[i,j] |
| 45 | +end |
| 46 | + |
| 47 | +ishermitian(D::SDiagonal{N, T}) where {N,T<:Real} = true |
| 48 | +ishermitian(D::SDiagonal) = all(D.diag .== real(D.diag)) |
| 49 | +issymmetric(D::SDiagonal) = true |
| 50 | +isposdef(D::SDiagonal) = all(D.diag .> 0) |
| 51 | + |
| 52 | +factorize(D::SDiagonal) = D |
| 53 | + |
| 54 | +==(Da::SDiagonal, Db::SDiagonal) = Da.diag == Db.diag |
| 55 | +-(A::SDiagonal) = SDiagonal(-A.diag) |
| 56 | ++(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag + Db.diag) |
| 57 | +-(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag - Db.diag) |
| 58 | +-(A::SDiagonal, B::SMatrix) = eye(typeof(B))*A - B |
| 59 | + |
| 60 | +*(x::T, D::SDiagonal) where {T<:Number} = SDiagonal(x * D.diag) |
| 61 | +*(D::SDiagonal, x::T) where {T<:Number} = SDiagonal(D.diag * x) |
| 62 | +/(D::SDiagonal, x::T) where {T<:Number} = SDiagonal(D.diag / x) |
| 63 | +*(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag .* Db.diag) |
| 64 | +*(D::SDiagonal, V::AbstractVector) = D.diag .* V |
| 65 | +*(D::SDiagonal, V::StaticVector) = D.diag .* V |
| 66 | +*(A::StaticMatrix, D::SDiagonal) = scalem(A,D.diag) |
| 67 | +*(D::SDiagonal, A::StaticMatrix) = scalem(D.diag,A) |
| 68 | +\(D::SDiagonal, b::AbstractVector) = D.diag .\ b |
| 69 | +\(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity |
| 70 | + |
| 71 | +conj(D::SDiagonal) = SDiagonal(conj(D.diag)) |
| 72 | +transpose(D::SDiagonal) = D |
| 73 | +ctranspose(D::SDiagonal) = conj(D) |
| 74 | + |
| 75 | +diag(D::SDiagonal) = D.diag |
| 76 | +trace(D::SDiagonal) = sum(D.diag) |
| 77 | +det(D::SDiagonal) = prod(D.diag) |
| 78 | +logdet{N,T<:Real}(D::SDiagonal{N,T}) = sum(log.(D.diag)) |
| 79 | +function logdet(D::SDiagonal{N,T}) where {N,T<:Complex} #Make sure branch cut is correct |
| 80 | + x = sum(log.(D.diag)) |
| 81 | + -pi<imag(x)<pi ? x : real(x)+(mod2pi(imag(x)+pi)-pi)*im |
| 82 | +end |
| 83 | + |
| 84 | +eye(::Type{SDiagonal{N,T}}) where {N,T} = SDiagonal(ones(SVector{N,T})) |
| 85 | + |
| 86 | +expm(D::SDiagonal) = SDiagonal(exp.(D.diag)) |
| 87 | +logm(D::SDiagonal) = SDiagonal(log.(D.diag)) |
| 88 | +sqrtm(D::SDiagonal) = SDiagonal(sqrt.(D.diag)) |
| 89 | + |
| 90 | +\(D::SDiagonal, B::StaticMatrix) = scalem(1 ./ D.diag, B) |
| 91 | +/(B::StaticMatrix, D::SDiagonal) = scalem(1 ./ D.diag, B) |
| 92 | +\(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Db.diag ./ Da.diag) |
| 93 | +/(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag ./ Db.diag ) |
| 94 | + |
| 95 | + |
| 96 | +@generated function check_singular(D::SDiagonal{N}) where {N} |
| 97 | + quote |
| 98 | + Base.Cartesian.@nexprs $N i->(@inbounds iszero(D.diag[i]) && throw(Base.LinAlg.SingularException(i))) |
| 99 | + end |
| 100 | +end |
| 101 | + |
| 102 | +function inv(D::SDiagonal) |
| 103 | + check_singular(D) |
| 104 | + SDiagonal(inv.(D.diag)) |
| 105 | +end |
| 106 | + |
0 commit comments