Skip to content

Commit 19ecce7

Browse files
authored
Fix exp of SMatrix with large entries #1295 (#1296)
Fix #1295 Appreciate if you can review this @mikmoore.
1 parent 1af9ba6 commit 19ecce7

File tree

3 files changed

+40
-10
lines changed

3 files changed

+40
-10
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.9.12"
3+
version = "1.9.13"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/expm.jl

+35-9
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ end
1414
(newtype)((exp(A[1]), ))
1515
end
1616

17+
# Bernstein, D. S. & So, W. 1993. "Some Explicit Formulas for the Matrix Exponential"
1718
@inline function _exp(::Size{(2,2)}, A::StaticMatrix{<:Any,<:Any,<:Real})
1819
T = typeof(exp(zero(eltype(A))))
1920
newtype = similar_type(A,T)
@@ -25,25 +26,50 @@ end
2526

2627
v = (a-d)^2 + 4*b*c
2728

29+
m = (a + d) / 2
30+
2831
if v > 0
29-
z = sqrt(v)
30-
z1 = cosh(z / 2)
31-
z2 = sinh(z / 2) / z
32+
# In this case the formulas in the entries of the matrix
33+
# are a function of cosh and sinh, and could (in theory)
34+
# follow the same code pattern of the other branches (v ≤ 0).
35+
# However, cosh and sinh explode with large arguments and
36+
# we use the following identity to avoid numerical issues:
37+
#
38+
# exp(m) * [c₁ * cohs(δ) + c₂ * sinh(δ)] =
39+
# c₁ * (e₊ + e₋) / 2 + c₂ * (e₊ - e₋) / 2
40+
#
41+
# where e₊ = exp(m + δ) and e₋ = exp(m - δ).
42+
#
43+
# See https://github.com/JuliaArrays/StaticArrays.jl/issues/1295
44+
δ = sqrt(v) / 2
45+
e₊ = exp(m + δ)
46+
e₋ = exp(m - δ)
47+
e₁ = (e₊ + e₋) / 2
48+
e₂ = (e₊ - e₋) / 2
49+
c₂ = (a - d) / 2δ
50+
m11 = (e₁ + c₂ * e₂)
51+
m12 = (b / δ) * e₂
52+
m21 = (c / δ) * e₂
53+
m22 = (e₁ - c₂ * e₂)
3254
elseif v < 0
3355
z = sqrt(-v)
56+
r = exp(m)
3457
z1 = cos(z / 2)
3558
z2 = sin(z / 2) / z
59+
m11 = r * (z1 + (a - d) * z2)
60+
m12 = r * 2b * z2
61+
m21 = r * 2c * z2
62+
m22 = r * (z1 - (a - d) * z2)
3663
else # if v == 0
64+
r = exp(m)
3765
z1 = T(1.0)
3866
z2 = T(0.5)
67+
m11 = r * (z1 + (a - d) * z2)
68+
m12 = r * 2b * z2
69+
m21 = r * 2c * z2
70+
m22 = r * (z1 - (a - d) * z2)
3971
end
4072

41-
r = exp((a + d) / 2)
42-
m11 = r * (z1 + (a - d) * z2)
43-
m12 = r * 2b * z2
44-
m21 = r * 2c * z2
45-
m22 = r * (z1 - (a - d) * z2)
46-
4773
(newtype)((m11, m21, m12, m22))
4874
end
4975

test/expm.jl

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ using StaticArrays, Test, LinearAlgebra
1313
@test exp(@SMatrix zeros(Complex{Float64}, 2, 2))::SMatrix Complex{Float64}[1 0; 0 1]
1414
@test exp(@SMatrix [1 2 0; 2 1 0; 0 0 1]) exp([1 2 0; 2 1 0; 0 0 1])
1515

16+
# https://github.com/JuliaArrays/StaticArrays.jl/issues/1295
17+
@test exp(@SMatrix [-800.0 800.0; 800.0 -800.0])::SMatrix [0.5 0.5; 0.5 0.5]
18+
@test exp(@SMatrix [-800.0 800.0; 800.0 -800.0]) exp([-800.0 800.0; 800.0 -800.0])
19+
1620
for sz in (3,4), typ in (Float64, Complex{Float64})
1721
A = rand(typ, sz, sz)
1822
nA = norm(A, 1)

0 commit comments

Comments
 (0)