Skip to content

Commit 244ada3

Browse files
authored
fallback randn/randexp for AbstractFloat (#44714)
* fallback randn/randexp for AbstractFloat
1 parent f536b81 commit 244ada3

File tree

4 files changed

+47
-11
lines changed

4 files changed

+47
-11
lines changed

NEWS.md

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ Standard library changes
6565

6666
#### Random
6767

68+
* `randn` and `randexp` now work for any `AbstractFloat` type defining `rand` ([#44714]).
69+
6870
#### REPL
6971

7072
#### SparseArrays

stdlib/Random/src/Random.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,10 @@ Sampler(rng::AbstractRNG, ::Type{X}, r::Repetition=Val(Inf)) where {X} =
143143

144144
typeof_rng(rng::AbstractRNG) = typeof(rng)
145145

146-
Sampler(::Type{<:AbstractRNG}, sp::Sampler, ::Repetition) =
147-
throw(ArgumentError("Sampler for this object is not defined"))
146+
# this method is necessary to prevent rand(rng::AbstractRNG, X) from
147+
# recursively constructing nested Sampler types.
148+
Sampler(T::Type{<:AbstractRNG}, sp::Sampler, r::Repetition) =
149+
throw(MethodError(Sampler, (T, sp, r)))
148150

149151
# default shortcut for the general case
150152
Sampler(::Type{RNG}, X) where {RNG<:AbstractRNG} = Sampler(RNG, X, Val(Inf))

stdlib/Random/src/normal.jl

+12
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,15 @@ randn(rng::AbstractRNG, ::Type{Complex{T}}) where {T<:AbstractFloat} =
9090
Complex{T}(SQRT_HALF * randn(rng, T), SQRT_HALF * randn(rng, T))
9191

9292

93+
### fallback randn for float types defining rand:
94+
function randn(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat}
95+
# Marsaglia polar variant of Box–Muller transform:
96+
while true
97+
x, y = 2rand(rng, T)-1, 2rand(rng, T)-1
98+
0 < (s = x^2 + y^2) < 1 && return x * sqrt(-2log(s)/s)
99+
end
100+
end
101+
93102
## randexp
94103

95104
"""
@@ -137,6 +146,9 @@ end
137146
end
138147
end
139148

149+
### fallback randexp for float types defining rand:
150+
randexp(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat} =
151+
-log1p(-rand(rng, T))
140152

141153
## arrays & other scalar methods
142154

stdlib/Random/test/runtests.jl

+29-9
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ let A = zeros(2, 2)
4747
0.9103565379264364 0.17732884646626457]
4848
end
4949
let A = zeros(2, 2)
50-
@test_throws ArgumentError rand!(MersenneTwister(0), A, 5)
50+
@test_throws MethodError rand!(MersenneTwister(0), A, 5)
5151
@test rand(MersenneTwister(0), Int64, 1) == [-3433174948434291912]
5252
end
5353
let A = zeros(Int64, 2, 2)
@@ -307,9 +307,32 @@ let a = [rand(RandomDevice(), UInt128) for i=1:10]
307307
@test reduce(|, a)>>>64 != 0
308308
end
309309

310+
# wrapper around Float64 to check fallback random generators
311+
struct FakeFloat64 <: AbstractFloat
312+
x::Float64
313+
end
314+
Base.rand(rng::AbstractRNG, ::Random.SamplerTrivial{Random.CloseOpen01{FakeFloat64}}) = FakeFloat64(rand(rng))
315+
for f in (:sqrt, :log, :log1p, :one, :zero, :abs, :+, :-)
316+
@eval Base.$f(x::FakeFloat64) = FakeFloat64($f(x.x))
317+
end
318+
for f in (:+, :-, :*, :/)
319+
@eval begin
320+
Base.$f(x::FakeFloat64, y::FakeFloat64) = FakeFloat64($f(x.x,y.x))
321+
Base.$f(x::FakeFloat64, y::Real) = FakeFloat64($f(x.x,y))
322+
Base.$f(x::Real, y::FakeFloat64) = FakeFloat64($f(x,y.x))
323+
end
324+
end
325+
for f in (:<, :<=, :>, :>=, :(==), :(!=))
326+
@eval begin
327+
Base.$f(x::FakeFloat64, y::FakeFloat64) = $f(x.x,y.x)
328+
Base.$f(x::FakeFloat64, y::Real) = $f(x.x,y)
329+
Base.$f(x::Real, y::FakeFloat64) = $f(x,y.x)
330+
end
331+
end
332+
310333
# test all rand APIs
311334
for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
312-
ftypes = [Float16, Float32, Float64]
335+
ftypes = [Float16, Float32, Float64, FakeFloat64, BigFloat]
313336
cftypes = [ComplexF16, ComplexF32, ComplexF64, ftypes...]
314337
types = [Bool, Char, BigFloat, Base.BitInteger_types..., ftypes...]
315338
randset = Set(rand(Int, 20))
@@ -406,15 +429,12 @@ for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
406429
rand!(rng..., BitMatrix(undef, 2, 3)) ::BitArray{2}
407430

408431
# Test that you cannot call randn or randexp with non-Float types.
409-
for r in [randn, randexp, randn!, randexp!]
410-
local r
432+
for r in [randn, randexp]
411433
@test_throws MethodError r(Int)
412434
@test_throws MethodError r(Int32)
413435
@test_throws MethodError r(Bool)
414436
@test_throws MethodError r(String)
415437
@test_throws MethodError r(AbstractFloat)
416-
# TODO(#17627): Consider adding support for randn(BigFloat) and removing this test.
417-
@test_throws MethodError r(BigFloat)
418438

419439
@test_throws MethodError r(Int64, (2,3))
420440
@test_throws MethodError r(String, 1)
@@ -664,7 +684,7 @@ let b = ['0':'9';'A':'Z';'a':'z']
664684
end
665685

666686
# this shouldn't crash (#22403)
667-
@test_throws ArgumentError rand!(Union{UInt,Int}[1, 2, 3])
687+
@test_throws MethodError rand!(Union{UInt,Int}[1, 2, 3])
668688

669689
@testset "$RNG() & Random.seed!(rng::$RNG) initializes randomly" for RNG in (MersenneTwister, RandomDevice, Xoshiro)
670690
m = RNG()
@@ -736,8 +756,8 @@ end
736756

737757
struct RandomStruct23964 end
738758
@testset "error message when rand not defined for a type" begin
739-
@test_throws ArgumentError rand(nothing)
740-
@test_throws ArgumentError rand(RandomStruct23964())
759+
@test_throws MethodError rand(nothing)
760+
@test_throws MethodError rand(RandomStruct23964())
741761
end
742762

743763
@testset "rand(::$(typeof(RNG)), ::UnitRange{$T}" for RNG (MersenneTwister(rand(UInt128)), RandomDevice(), Xoshiro()),

0 commit comments

Comments
 (0)