Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Porting SDiagonal from Bridge.jl #240

Merged
merged 9 commits into from
Jul 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions src/SDiagonal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Originally contributed by D. Getz (https://github.com/getzdan), M. Schauer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, if you're porting someone else's code, and you want to attribute them, it's also quite possible to make them the author in git using something like git commit --author="Some Body <[email protected]>"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I do this kind of thing, I'd be inclined to add the original chunk of code (in probably non-working form) under the other author's name as a single commit (@mentioning them to make sure they're happy with that). Then add any necessary changes under your own name in further commits.

The nice thing about doing it this way is you avoid baking authorship into comments which people feel they can't remove (like the one above), but you also get to do proper attribution which I feel is quite important.

Just some thoughts, I'm happy this was merged already.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you port entire files (like in this case) it is also possible to preserve the full git history http://gbayer.com/development/moving-files-from-one-git-repository-to-another-preserving-history/

# at https://github.com/mschauer/Bridge.jl under MIT License

import Base: ==, -, +, *, /, \, abs, real, imag, conj

@generated function scalem(a::StaticMatrix{M,N}, b::StaticVector{N}) where {M, N}
expr = vec([:(a[$j,$i]*b[$i]) for j=1:M, i=1:N])
:(@_inline_meta; let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end)
end
@generated function scalem(a::StaticVector{M}, b::StaticMatrix{M, N}) where {M, N}
expr = vec([:(b[$j,$i]*a[$j]) for j=1:M, i=1:N])
:(@_inline_meta; let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end)
end

struct SDiagonal{N,T} <: StaticMatrix{N,N,T}
diag::SVector{N,T}
SDiagonal{N,T}(diag::SVector{N,T}) where {N,T} = new(diag)
end
diagtype(::Type{SDiagonal{N,T}}) where {N, T} = SVector{N,T}
diagtype(::Type{SDiagonal{N}}) where {N} = SVector{N}
diagtype(::Type{SDiagonal}) = SVector

# this is to deal with convert.jl
@inline (::Type{SD})(a::AbstractVector) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a))
@inline (::Type{SD})(a::Tuple) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a))
@inline (::Type{SDiagonal})(a::SVector{N,T}) where {N,T} = SDiagonal{N,T}(a)

@generated function SDiagonal(a::StaticMatrix{N,N,T}) where {N,T}
expr = [:(a[$i,$i]) for i=1:N]
:(SDiagonal{N,T}($(expr...)))
end

convert(::Type{SDiagonal{N,T}}, D::SDiagonal{N,T}) where {N,T} = D
convert(::Type{SDiagonal{N,T}}, D::SDiagonal{N}) where {N,T} = SDiagonal{N,T}(convert(SVector{N,T}, D.diag))

function getindex(D::SDiagonal{N,T}, i::Int, j::Int) where {N,T}
@boundscheck checkbounds(D, i, j)
@inbounds return ifelse(i == j, D.diag[i], zero(T))
end

# avoid linear indexing?
@propagate_inbounds function getindex(D::SDiagonal{N,T}, k::Int) where {N,T}
i, j = ind2sub(size(D), k)
D[i,j]
end

ishermitian(D::SDiagonal{N, T}) where {N,T<:Real} = true
ishermitian(D::SDiagonal) = all(D.diag .== real(D.diag))
issymmetric(D::SDiagonal) = true
isposdef(D::SDiagonal) = all(D.diag .> 0)

factorize(D::SDiagonal) = D

==(Da::SDiagonal, Db::SDiagonal) = Da.diag == Db.diag
-(A::SDiagonal) = SDiagonal(-A.diag)
+(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag + Db.diag)
-(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag - Db.diag)
-(A::SDiagonal, B::SMatrix) = eye(typeof(B))*A - B

*(x::T, D::SDiagonal) where {T<:Number} = SDiagonal(x * D.diag)
*(D::SDiagonal, x::T) where {T<:Number} = SDiagonal(D.diag * x)
/(D::SDiagonal, x::T) where {T<:Number} = SDiagonal(D.diag / x)
*(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag .* Db.diag)
*(D::SDiagonal, V::AbstractVector) = D.diag .* V
*(D::SDiagonal, V::StaticVector) = D.diag .* V
*(A::StaticMatrix, D::SDiagonal) = scalem(A,D.diag)
*(D::SDiagonal, A::StaticMatrix) = scalem(D.diag,A)
\(D::SDiagonal, b::AbstractVector) = D.diag .\ b
\(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity

conj(D::SDiagonal) = SDiagonal(conj(D.diag))
transpose(D::SDiagonal) = D
ctranspose(D::SDiagonal) = conj(D)

diag(D::SDiagonal) = D.diag
trace(D::SDiagonal) = sum(D.diag)
det(D::SDiagonal) = prod(D.diag)
logdet{N,T<:Real}(D::SDiagonal{N,T}) = sum(log.(D.diag))
function logdet(D::SDiagonal{N,T}) where {N,T<:Complex} #Make sure branch cut is correct
x = sum(log.(D.diag))
-pi<imag(x)<pi ? x : real(x)+(mod2pi(imag(x)+pi)-pi)*im
end

eye(::Type{SDiagonal{N,T}}) where {N,T} = SDiagonal(ones(SVector{N,T}))

expm(D::SDiagonal) = SDiagonal(exp.(D.diag))
logm(D::SDiagonal) = SDiagonal(log.(D.diag))
sqrtm(D::SDiagonal) = SDiagonal(sqrt.(D.diag))

\(D::SDiagonal, B::StaticMatrix) = scalem(1 ./ D.diag, B)
/(B::StaticMatrix, D::SDiagonal) = scalem(1 ./ D.diag, B)
\(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Db.diag ./ Da.diag)
/(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag ./ Db.diag )


@generated function check_singular(D::SDiagonal{N}) where {N}
quote
Base.Cartesian.@nexprs $N i->(@inbounds iszero(D.diag[i]) && throw(Base.LinAlg.SingularException(i)))
end
end

function inv(D::SDiagonal)
check_singular(D)
SDiagonal(inv.(D.diag))
end

6 changes: 4 additions & 2 deletions src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import Base: getindex, setindex!, size, similar, vec, show,
length, convert, promote_op, promote_rule, map, map!, reduce, reducedim, mapreducedim,
mapreduce, broadcast, broadcast!, conj, transpose, ctranspose,
hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill,
fill!, det, inv, eig, eigvals, expm, sqrtm, trace, vecnorm, norm, dot, diagm, diag,
lu, svd, svdvals, svdfact,
fill!, det, logdet, inv, eig, eigvals, expm, logm, sqrtm, trace, diag, vecnorm, norm, dot, diagm, diag,
lu, svd, svdvals, svdfact, factorize, ishermitian, issymmetric, isposdef,
sum, diff, prod, count, any, all, minimum,
maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!,
randexp!, normalize, normalize!, read, read!, write
Expand All @@ -19,6 +19,7 @@ export Scalar, SArray, SVector, SMatrix
export MArray, MVector, MMatrix
export FieldVector
export SizedArray, SizedVector, SizedMatrix
export SDiagonal

export Size, Length

Expand Down Expand Up @@ -79,6 +80,7 @@ include("MArray.jl")
include("MVector.jl")
include("MMatrix.jl")
include("SizedArray.jl")
include("SDiagonal.jl")

include("abstractarray.jl")
include("indexing.jl")
Expand Down
104 changes: 104 additions & 0 deletions test/SDiagonal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
@testset "SDiagonal" begin
@testset "Constructors" begin
@test SDiagonal{1,Int64}((1,)).diag === SVector{1,Int64}((1,))
@test SDiagonal{1,Float64}((1,)).diag === SVector{1,Float64}((1,))

@test SDiagonal{4,Float64}((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0)
@test SDiagonal{4}((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0)
@test SDiagonal((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0)

# Bad input
@test_throws Exception SMatrix{1,Int}()
@test_throws Exception SMatrix{2,Int}((1,))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@test_throws Exception is generally a very bad idea. It will pass if the code has a typo.


# From SMatrix
@test SDiagonal(SMatrix{2,2,Int}((1,2,3,4))).diag.data === (1,4)

@test SDiagonal{1,Int}(SDiagonal{1,Float64}((1,))).diag[1] === 1

end

@testset "Methods" begin

@test StaticArrays.scalem(@SMatrix([1 1 1;1 1 1; 1 1 1]), @SVector [1,2,3]) === @SArray [1 2 3; 1 2 3; 1 2 3]
@test StaticArrays.scalem(@SVector([1,2,3]),@SMatrix [1 1 1;1 1 1; 1 1 1])' === @SArray [1 2 3; 1 2 3; 1 2 3]

m = SDiagonal(@SVector [11, 12, 13, 14])



@test diag(m) === m.diag


m2 = diagm([11, 12, 13, 14])

@test logdet(m) == logdet(m2)
@test logdet(im*m) ≈ logdet(im*m2)
@test det(m) == det(m2)
@test trace(m) == trace(m2)
@test logm(m) == logm(m2)
@test expm(m) == expm(m2)
@test sqrtm(m) == sqrtm(m2)


@test isimmutable(m) == true

@test m[1,1] === 11
@test m[2,2] === 12
@test m[3,3] === 13
@test m[4,4] === 14

for i in 1:4
for j in 1:4
i == j || @test m[i,j] === 0
end
end

@test_throws Exception m[5,5]

@test_throws Exception m[1,5]


@test size(m) === (4, 4)
@test size(typeof(m)) === (4, 4)
@test size(SDiagonal{4}) === (4, 4)

@test size(m, 1) === 4
@test size(m, 2) === 4
@test size(typeof(m), 1) === 4
@test size(typeof(m), 2) === 4

@test length(m) === 4*4

@test_throws Exception m[1] = 1

b = @SVector [2,-1,2,1]
b2 = Vector(b)


@test m*b == @SVector [22,-12,26,14]
@test (b'*m)' == @SVector [22,-12,26,14]

@test m\b == m2\b

@test b'/m == b'/m2
@test_throws Exception b/m
@test m*m == m2*m

@test ishermitian(m) == ishermitian(m2)
@test ishermitian(m/2)

@test isposdef(m) == isposdef(m2)
@test issymmetric(m) == issymmetric(m2)

@test (2*m/2)' == m
@test 2m == m + m
@test m*0 == m - m

@test m*inv(m) == m/m == m\m == eye(SDiagonal{4,Float64})




end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ include("FieldVector.jl")
include("Scalar.jl")
include("SUnitRange.jl")
include("SizedArray.jl")
include("SDiagonal.jl")

include("custom_types.jl")

include("core.jl")
Expand Down