Skip to content

Commit 57f2526

Browse files
committed
Fixes and tests for constructors
1 parent 1eeb98b commit 57f2526

File tree

3 files changed

+64
-10
lines changed

3 files changed

+64
-10
lines changed

src/SDiagonal.jl

+11-10
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@ struct SDiagonal{N,T} <: StaticMatrix{N, N, T}
1717
diag::SVector{N,T}
1818
SDiagonal{N,T}(diag::SVector{N,T}) where {N,T} = new(diag)
1919
end
20+
diagtype{N,T}(::Type{SDiagonal{N,T}}) = SVector{N,T}
21+
diagtype{N}(::Type{SDiagonal{N}}) = SVector{N}
2022

2123
# this is to deal with convert.jl
22-
@inline (::Type{SDiagonal})(a::AbstractVector) = SDiagonal(SVector(a))
24+
@inline (::Type{SD})(a::AbstractVector) where {SD <: SDiagonal} = SD(diagtype(SD)(a))
25+
@inline (::Type{SD})(a::Tuple) where {SD <: SDiagonal} = SD(diagtype(SD)(a))
2326
@inline (::Type{SDiagonal}){N,T}(a::SVector{N,T}) = SDiagonal{N,T}(a)
2427

2528
@generated function SDiagonal{N,T}(a::SMatrix{N,N,T})
@@ -34,23 +37,25 @@ end
3437
convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal{N,T}) = D
3538
convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal) = SDiagonal{N,T}(convert(SVector{N,T}, D.diag))
3639

37-
size(D::SDiagonal) = (length(D.diag),length(D.diag))
40+
size{N}(D::SDiagonal{N}) = (N,N)
3841

39-
function size(D::SDiagonal,d::Integer)
42+
function size{N}(D::SDiagonal{N},d::Int64)
4043
if d<1
4144
throw(ArgumentError("dimension must be ≥ 1, got $d"))
4245
end
43-
return d<=2 ? length(D.diag) : 1
46+
return d<=2 ? N : 1
4447
end
4548

46-
function getindex{T}(D::SDiagonal{T}, i::Int, j::Int)
49+
Base.@propagate_inbounds function getindex{T}(D::SDiagonal{T}, i::Int, j::Int)
50+
@boundscheck checkbounds(D, i, j)
4751
if i == j
4852
D.diag[i]
4953
else
5054
zero(T)
5155
end
5256
end
5357

58+
# linear indexing?
5459

5560
ishermitian{T<:Real}(D::SDiagonal{T}) = true
5661
ishermitian(D::SDiagonal) = all(D.diag .== real(D.diag))
@@ -59,10 +64,6 @@ isposdef(D::SDiagonal) = all(D.diag .> 0)
5964

6065
factorize(D::SDiagonal) = D
6166

62-
abs(D::SDiagonal) = SDiagonal(abs(D.diag))
63-
real(D::SDiagonal) = SDiagonal(real(D.diag))
64-
imag(D::SDiagonal) = SDiagonal(imag(D.diag))
65-
6667
==(Da::SDiagonal, Db::SDiagonal) = Da.diag == Db.diag
6768
-(A::SDiagonal) = SDiagonal(-A.diag)
6869
+(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag + Db.diag)
@@ -95,7 +96,7 @@ function logdet{N,T<:Complex}(D::SDiagonal{N,T}) #Make sure branch cut is correc
9596
end
9697

9798

98-
eye{N,T}(::Type{SDiagonal{N,T}}) = SDiagonal(one(SVector{n,Int}))
99+
eye{N,T}(::Type{SDiagonal{N,T}}) = SDiagonal(one(SVector{N,T}))
99100

100101
expm(D::SDiagonal) = SDiagonal(exp.(D.diag))
101102
logm(D::SDiagonal) = SDiagonal(log.(D.diag))

test/SDiagonal.jl

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
@testset "SDiagonal" begin
2+
@testset "Constructors" begin
3+
@test SDiagonal{1,Int64}((1,)).diag === SVector{1,Int64}((1,))
4+
@test SDiagonal{1,Float64}((1,)).diag === SVector{1,Float64}((1,))
5+
6+
@test SDiagonal{4,Float64}((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0)
7+
8+
# Bad input
9+
@test_throws Exception SMatrix{1,Int}()
10+
@test_throws Exception SMatrix{2,Int}((1,))
11+
12+
# From SMatrix
13+
@test SDiagonal(SMatrix{2,2,Int}((1,2,3,4))).diag.data === (1,4)
14+
15+
end
16+
17+
@testset "Methods" begin
18+
m = SDiagonal(@SVector [11, 12, 13, 14])
19+
20+
@test isimmutable(m) == true
21+
22+
@test m[1,1] === 11
23+
@test m[2,2] === 12
24+
@test m[3,3] === 13
25+
@test m[4,4] === 14
26+
27+
for i in 1:4
28+
for j in 1:4
29+
i == j || @test m[i,j] === 0
30+
end
31+
end
32+
33+
@test_throws Exception m[5,5]
34+
35+
@test_throws Exception m[1,5]
36+
37+
38+
@test size(m) === (4, 4)
39+
@test size(typeof(m)) === (4, 4)
40+
@test size(SDiagonal{4}) === (4, 4)
41+
42+
@test size(m, 1) === 4
43+
@test size(m, 2) === 4
44+
@test size(typeof(m), 1) === 4
45+
@test size(typeof(m), 2) === 4
46+
47+
@test length(m) === 4*4
48+
49+
@test_throws Exception m[1] = 1
50+
end
51+
end

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ include("FieldVector.jl")
1313
include("Scalar.jl")
1414
include("SUnitRange.jl")
1515
include("SizedArray.jl")
16+
include("SDiagonal.jl")
17+
1618
include("custom_types.jl")
1719

1820
include("core.jl")

0 commit comments

Comments
 (0)