Skip to content

Commit 7708986

Browse files
authored
Specialize triu/tril for StaticMatrix (JuliaArrays#1241)
1 parent 431d57a commit 7708986

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

Diff for: src/StaticArrays.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ using LinearAlgebra
1616
import LinearAlgebra: transpose, adjoint, dot, eigvals, eigen, lyap, tr,
1717
kron, diag, norm, dot, diagm, lu, svd, svdvals, pinv,
1818
factorize, ishermitian, issymmetric, isposdef, issuccess, normalize,
19-
normalize!, Eigen, det, logdet, logabsdet, cross, diff, qr, \
19+
normalize!, Eigen, det, logdet, logabsdet, cross, diff, qr, \,
20+
triu, tril
2021
using LinearAlgebra: checksquare
2122

2223
using PrecompileTools

Diff for: src/linalg.jl

+34
Original file line numberDiff line numberDiff line change
@@ -522,3 +522,37 @@ end
522522
# Some shimming for special linear algebra matrix types
523523
@inline LinearAlgebra.Symmetric(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Symmetric{eltype(A),typeof(A)}(A, uplo))
524524
@inline LinearAlgebra.Hermitian(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Hermitian{eltype(A),typeof(A)}(A, uplo))
525+
526+
# triu/tril
527+
function triu(S::StaticMatrix, k::Int=0)
528+
if length(S) <= 32
529+
C = CartesianIndices(S)
530+
t = Tuple(S)
531+
for (linind, CI) in enumerate(C)
532+
i, j = Tuple(CI)
533+
if j-i < k
534+
t = Base.setindex(t, zero(t[linind]), linind)
535+
end
536+
end
537+
similar_type(S)(t)
538+
else
539+
M = triu!(copyto!(similar(S), S), k)
540+
similar_type(S)(M)
541+
end
542+
end
543+
function tril(S::StaticMatrix, k::Int=0)
544+
if length(S) <= 32
545+
C = CartesianIndices(S)
546+
t = Tuple(S)
547+
for (linind, CI) in enumerate(C)
548+
i, j = Tuple(CI)
549+
if j-i > k
550+
t = Base.setindex(t, zero(t[linind]), linind)
551+
end
552+
end
553+
similar_type(S)(t)
554+
else
555+
M = tril!(copyto!(similar(S), S), k)
556+
similar_type(S)(M)
557+
end
558+
end

Diff for: test/linalg.jl

+10
Original file line numberDiff line numberDiff line change
@@ -471,4 +471,14 @@ end
471471
m23 = SA[1 2 3; 4 5 6]
472472
@test_inlined checksquare(m23) false
473473
end
474+
475+
@testset "triu/tril" begin
476+
for S in (SMatrix{7,5}(1:35), MMatrix{4,6}(1:24), SizedArray{Tuple{2,2}}([1 2; 3 4]))
477+
M = Matrix(S)
478+
for k in -10:10
479+
@test triu(S, k) == triu(M, k)
480+
@test tril(S, k) == tril(M, k)
481+
end
482+
end
483+
end
474484
end

0 commit comments

Comments
 (0)