@@ -17,9 +17,12 @@ struct SDiagonal{N,T} <: StaticMatrix{N, N, T}
17
17
diag:: SVector{N,T}
18
18
SDiagonal {N,T} (diag:: SVector{N,T} ) where {N,T} = new (diag)
19
19
end
20
+ diagtype {N,T} (:: Type{SDiagonal{N,T}} ) = SVector{N,T}
21
+ diagtype {N} (:: Type{SDiagonal{N}} ) = SVector{N}
20
22
21
23
# this is to deal with convert.jl
22
- @inline (:: Type{SDiagonal} )(a:: AbstractVector ) = SDiagonal (SVector (a))
24
+ @inline (:: Type{SD} )(a:: AbstractVector ) where {SD <: SDiagonal } = SD (diagtype (SD)(a))
25
+ @inline (:: Type{SD} )(a:: Tuple ) where {SD <: SDiagonal } = SD (diagtype (SD)(a))
23
26
@inline (:: Type{SDiagonal} ){N,T}(a:: SVector{N,T} ) = SDiagonal {N,T} (a)
24
27
25
28
@generated function SDiagonal {N,T} (a:: SMatrix{N,N,T} )
34
37
convert {N,T} (:: Type{SDiagonal{N,T}} , D:: SDiagonal{N,T} ) = D
35
38
convert {N,T} (:: Type{SDiagonal{N,T}} , D:: SDiagonal ) = SDiagonal {N,T} (convert (SVector{N,T}, D. diag))
36
39
37
- size (D:: SDiagonal ) = (length (D . diag), length (D . diag) )
40
+ size {N} (D:: SDiagonal{N} ) = (N,N )
38
41
39
- function size (D:: SDiagonal ,d:: Integer )
42
+ function size {N} (D:: SDiagonal{N} ,d:: Int64 )
40
43
if d< 1
41
44
throw (ArgumentError (" dimension must be ≥ 1, got $d " ))
42
45
end
43
- return d<= 2 ? length (D . diag) : 1
46
+ return d<= 2 ? N : 1
44
47
end
45
48
46
- function getindex {T} (D:: SDiagonal{T} , i:: Int , j:: Int )
49
+ Base. @propagate_inbounds function getindex {T} (D:: SDiagonal{T} , i:: Int , j:: Int )
50
+ @boundscheck checkbounds (D, i, j)
47
51
if i == j
48
52
D. diag[i]
49
53
else
50
54
zero (T)
51
55
end
52
56
end
53
57
58
+ # linear indexing?
54
59
55
60
ishermitian {T<:Real} (D:: SDiagonal{T} ) = true
56
61
ishermitian (D:: SDiagonal ) = all (D. diag .== real (D. diag))
@@ -59,10 +64,6 @@ isposdef(D::SDiagonal) = all(D.diag .> 0)
59
64
60
65
factorize (D:: SDiagonal ) = D
61
66
62
- abs (D:: SDiagonal ) = SDiagonal (abs (D. diag))
63
- real (D:: SDiagonal ) = SDiagonal (real (D. diag))
64
- imag (D:: SDiagonal ) = SDiagonal (imag (D. diag))
65
-
66
67
== (Da:: SDiagonal , Db:: SDiagonal ) = Da. diag == Db. diag
67
68
- (A:: SDiagonal ) = SDiagonal (- A. diag)
68
69
+ (Da:: SDiagonal , Db:: SDiagonal ) = SDiagonal (Da. diag + Db. diag)
@@ -95,7 +96,7 @@ function logdet{N,T<:Complex}(D::SDiagonal{N,T}) #Make sure branch cut is correc
95
96
end
96
97
97
98
98
- eye {N,T} (:: Type{SDiagonal{N,T}} ) = SDiagonal (one (SVector{n,Int }))
99
+ eye {N,T} (:: Type{SDiagonal{N,T}} ) = SDiagonal (one (SVector{N,T }))
99
100
100
101
expm (D:: SDiagonal ) = SDiagonal (exp .(D. diag))
101
102
logm (D:: SDiagonal ) = SDiagonal (log .(D. diag))
0 commit comments