Skip to content

Commit 5f6b4da

Browse files
authored
improve support for custom Number types (#50)
1 parent 1d4cbb7 commit 5f6b4da

File tree

4 files changed

+56
-38
lines changed

4 files changed

+56
-38
lines changed

Project.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "InverseFunctions"
22
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
3-
version = "0.1.14"
3+
version = "0.1.15"
44

55
[deps]
66
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
@@ -18,6 +18,7 @@ julia = "1"
1818
[extras]
1919
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
2020
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
21+
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2122

2223
[targets]
23-
test = ["Dates", "Documenter"]
24+
test = ["Dates", "Documenter", "Unitful"]

src/functions.jl

+25-28
Original file line numberDiff line numberDiff line change
@@ -5,57 +5,46 @@
55
66
Inverse of `sqrt(x)` for non-negative `x`.
77
"""
8-
square(x) = x^2
9-
function square(x::Real)
10-
x < zero(x) && throw(DomainError(x, "`square` is defined as the inverse of `sqrt` and can only be evaluated for non-negative values"))
8+
function square(x)
9+
if is_real_type(typeof(x)) && x < zero(x)
10+
throw(DomainError(x, "`square` is defined as the inverse of `sqrt` and can only be evaluated for non-negative values"))
11+
end
1112
return x^2
1213
end
1314

1415

15-
function invpow2(x::Real, p::Integer)
16-
if x zero(x) || isodd(p)
17-
copysign(abs(x)^inv(p), x)
16+
function invpow_arg2(x::Number, p::Real)
17+
if is_real_type(typeof(x))
18+
x zero(x) ? x^inv(p) : # x > 0 - trivially invertible
19+
isinteger(p) && isodd(Integer(p)) ? copysign(abs(x)^inv(p), x) : # p odd - invertible even for x < 0
20+
throw(DomainError(x, "inverse for x^$p is not defined at $x"))
1821
else
19-
throw(DomainError(x, "inverse for x^$p is not defined at $x"))
20-
end
21-
end
22-
function invpow2(x::Real, p::Real)
23-
if x zero(x)
24-
x^inv(p)
25-
else
26-
throw(DomainError(x, "inverse for x^$p is not defined at $x"))
27-
end
28-
end
29-
function invpow2(x, p::Real)
30-
# complex x^p is only invertible for p = 1/n
31-
if isinteger(inv(p))
32-
x^inv(p)
33-
else
34-
throw(DomainError(x, "inverse for x^$p is not defined at $x"))
22+
# complex x^p is invertible only for p = 1/n
23+
isinteger(inv(p)) ? x^inv(p) : throw(DomainError(x, "inverse for x^$p is not defined at $x"))
3524
end
3625
end
3726

38-
function invpow1(b::Real, x::Real)
27+
function invpow_arg1(b::Real, x::Real)
3928
if b zero(b) && x zero(x)
4029
log(b, x)
4130
else
4231
throw(DomainError(x, "inverse for $b^x is not defined at $x"))
4332
end
4433
end
4534

46-
function invlog1(b::Real, x::Real)
35+
function invlog_arg1(b::Real, x::Real)
4736
if b zero(b)
4837
b^x
4938
else
5039
throw(DomainError(x, "inverse for log($b, x) is not defined at $x"))
5140
end
5241
end
53-
invlog1(b, x) = b^x
42+
invlog_arg1(b::Number, x::Number) = b^x
5443

55-
invlog2(b, x) = x^inv(b)
44+
invlog_arg2(b::Number, x::Number) = x^inv(b)
5645

5746

58-
function invdivrem((q, r), divisor)
47+
function invdivrem((q, r)::NTuple{2,Number}, divisor::Number)
5948
res = muladd(q, divisor, r)
6049
if abs(r) abs(divisor) && (iszero(r) || sign(r) == sign(res))
6150
res
@@ -64,10 +53,18 @@ function invdivrem((q, r), divisor)
6453
end
6554
end
6655

67-
function invfldmod((q, r), divisor)
56+
function invfldmod((q, r)::NTuple{2,Number}, divisor::Number)
6857
if abs(r) abs(divisor) && (iszero(r) || sign(r) == sign(divisor))
6958
muladd(q, divisor, r)
7059
else
7160
throw(DomainError((q, r), "inverse for fldmod(x) is not defined at this point"))
7261
end
7362
end
63+
64+
65+
# check if T is a real-Number type
66+
# this is not the same as T <: Real which immediately excludes custom Number subtypes such as unitful numbers
67+
# also, isreal(x) != is_real_type(typeof(x)): the former is true for complex numbers with zero imaginary part
68+
@inline is_real_type(@nospecialize _::Type{<:Real}) = true
69+
@inline is_real_type(::Type{T}) where {T<:Number} = real(T) == T
70+
@inline is_real_type(@nospecialize _::Type) = false

src/inverse.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,14 @@ inverse(::typeof(sqrt)) = square
159159
inverse(::typeof(square)) = sqrt
160160

161161
inverse(::typeof(cbrt)) = Base.Fix2(^, 3)
162-
inverse(f::Base.Fix2{typeof(^)}) = iszero(f.x) ? throw(DomainError(f.x, "Cannot invert x^$(f.x)")) : Base.Fix2(invpow2, f.x)
163-
inverse(f::Base.Fix2{typeof(invpow2)}) = Base.Fix2(^, f.x)
164-
inverse(f::Base.Fix1{typeof(^)}) = Base.Fix1(invpow1, f.x)
165-
inverse(f::Base.Fix1{typeof(invpow1)}) = Base.Fix1(^, f.x)
166-
inverse(f::Base.Fix1{typeof(log)}) = Base.Fix1(invlog1, f.x)
167-
inverse(f::Base.Fix1{typeof(invlog1)}) = Base.Fix1(log, f.x)
168-
inverse(f::Base.Fix2{typeof(log)}) = Base.Fix2(invlog2, f.x)
169-
inverse(f::Base.Fix2{typeof(invlog2)}) = Base.Fix2(log, f.x)
162+
inverse(f::Base.Fix2{typeof(^)}) = iszero(f.x) ? throw(DomainError(f.x, "Cannot invert x^$(f.x)")) : Base.Fix2(invpow_arg2, f.x)
163+
inverse(f::Base.Fix2{typeof(invpow_arg2)}) = Base.Fix2(^, f.x)
164+
inverse(f::Base.Fix1{typeof(^)}) = Base.Fix1(invpow_arg1, f.x)
165+
inverse(f::Base.Fix1{typeof(invpow_arg1)}) = Base.Fix1(^, f.x)
166+
inverse(f::Base.Fix1{typeof(log)}) = Base.Fix1(invlog_arg1, f.x)
167+
inverse(f::Base.Fix1{typeof(invlog_arg1)}) = Base.Fix1(log, f.x)
168+
inverse(f::Base.Fix2{typeof(log)}) = Base.Fix2(invlog_arg2, f.x)
169+
inverse(f::Base.Fix2{typeof(invlog_arg2)}) = Base.Fix2(log, f.x)
170170

171171
inverse(f::Base.Fix2{typeof(divrem)}) = Base.Fix2(invdivrem, f.x)
172172
inverse(f::Base.Fix2{typeof(invdivrem)}) = Base.Fix2(divrem, f.x)

test/test_inverse.jl

+20
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
using Test
44
using InverseFunctions
5+
using Unitful
56
using Dates
67

78

@@ -79,14 +80,19 @@ end
7980
end
8081

8182
# ensure that inverses have domains compatible with original functions
83+
@test_throws DomainError inverse(sqrt)(-1.0)
84+
InverseFunctions.test_inverse(sqrt, complex(-1.0))
85+
InverseFunctions.test_inverse(sqrt, complex(1.0))
8286
@test_throws DomainError inverse(Base.Fix1(*, 0))
8387
@test_throws DomainError inverse(Base.Fix2(^, 0))
8488
@test_throws DomainError inverse(Base.Fix1(log, -2))(5)
8589
@test_throws DomainError inverse(Base.Fix1(log, 2))(-5)
8690
InverseFunctions.test_inverse(inverse(Base.Fix1(log, 2)), complex(-5))
8791
@test_throws DomainError inverse(Base.Fix2(^, 0.5))(-5)
8892
@test_throws DomainError inverse(Base.Fix2(^, 0.51))(complex(-5))
93+
@test_throws DomainError inverse(Base.Fix2(^, 2))(complex(-5))
8994
InverseFunctions.test_inverse(Base.Fix2(^, 0.5), complex(-5))
95+
InverseFunctions.test_inverse(Base.Fix2(^, -1), complex(-5.))
9096
@test_throws DomainError inverse(Base.Fix2(^, 2))(-5)
9197
@test_throws DomainError inverse(Base.Fix1(^, 2))(-5)
9298
@test_throws DomainError inverse(Base.Fix1(^, -2))(3)
@@ -130,6 +136,20 @@ end
130136
end
131137
end
132138

139+
@testset "unitful" begin
140+
# the majority of inverse just propagate to underlying mathematical functions and don't have any issues with unitful numbers
141+
# only those that behave treat real numbers differently have to be tested here
142+
x = rand()u"m"
143+
InverseFunctions.test_inverse(sqrt, x)
144+
@test_throws DomainError inverse(sqrt)(-x)
145+
146+
InverseFunctions.test_inverse(Base.Fix2(^, 2), x)
147+
@test_throws DomainError inverse(Base.Fix2(^, 2))(-x)
148+
InverseFunctions.test_inverse(Base.Fix2(^, 3), x)
149+
InverseFunctions.test_inverse(Base.Fix2(^, 3), -x)
150+
InverseFunctions.test_inverse(Base.Fix2(^, -3.5), x)
151+
end
152+
133153
@testset "dates" begin
134154
InverseFunctions.test_inverse(Dates.date2epochdays, Date(2020, 1, 2); compare = ===)
135155
InverseFunctions.test_inverse(Dates.datetime2epochms, DateTime(2020, 1, 2, 12, 34, 56); compare = ===)

0 commit comments

Comments
 (0)