Skip to content

Commit 74a4d01

Browse files
committed
make TaskLocal the default RNG
1 parent 417e52e commit 74a4d01

File tree

3 files changed

+29
-46
lines changed

3 files changed

+29
-46
lines changed

stdlib/Random/src/RNGs.jl

+10-22
Original file line numberDiff line numberDiff line change
@@ -359,45 +359,33 @@ function seed!(r::MersenneTwister, seed::Vector{UInt32})
359359
return r
360360
end
361361

362-
seed!(r::MersenneTwister=default_rng()) = seed!(r, make_seed())
362+
seed!() = seed!(default_rng(), make_seed())
363+
seed!(r::MersenneTwister) = seed!(r, make_seed())
363364
seed!(r::MersenneTwister, n::Integer) = seed!(r, make_seed(n))
364-
seed!(seed::Union{Integer,Vector{UInt32}}) = seed!(default_rng(), seed)
365+
seed!(seed::Union{Integer,Vector{UInt32},Vector{UInt64}}) = seed!(default_rng(), seed)
365366

366367

367368
### Global RNG
368369

369-
const THREAD_RNGs = MersenneTwister[]
370-
@inline default_rng() = default_rng(Threads.threadid())
371-
@noinline function default_rng(tid::Int)
372-
0 < tid <= length(THREAD_RNGs) || _rng_length_assert()
373-
if @inbounds isassigned(THREAD_RNGs, tid)
374-
@inbounds MT = THREAD_RNGs[tid]
375-
else
376-
MT = MersenneTwister()
377-
@inbounds THREAD_RNGs[tid] = MT
378-
end
379-
return MT
380-
end
381-
@noinline _rng_length_assert() = @assert false "0 < tid <= length(THREAD_RNGs)"
370+
@inline default_rng() = TaskLocalRNG()
371+
@inline default_rng(tid::Int) = TaskLocalRNG()
382372

383373
function __init__()
384-
resize!(empty!(THREAD_RNGs), Threads.nthreads()) # ensures that we didn't save a bad object
385374
seed!(TaskLocalRNG())
386375
end
387376

388-
389377
struct _GLOBAL_RNG <: AbstractRNG
390378
global const GLOBAL_RNG = _GLOBAL_RNG.instance
391379
end
392380

393-
# GLOBAL_RNG currently represents a MersenneTwister
394-
typeof_rng(::_GLOBAL_RNG) = MersenneTwister
381+
# GLOBAL_RNG currently uses TaskLocalRNG
382+
typeof_rng(::_GLOBAL_RNG) = TaskLocalRNG
395383

396-
copy!(dst::MersenneTwister, ::_GLOBAL_RNG) = copy!(dst, default_rng())
397-
copy!(::_GLOBAL_RNG, src::MersenneTwister) = copy!(default_rng(), src)
384+
copy!(dst::Xoshiro, ::_GLOBAL_RNG) = copy!(dst, default_rng())
385+
copy!(::_GLOBAL_RNG, src::Xoshiro) = copy!(default_rng(), src)
398386
copy(::_GLOBAL_RNG) = copy(default_rng())
399387

400-
seed!(::_GLOBAL_RNG, seed::Vector{UInt32}) = seed!(default_rng(), seed)
388+
seed!(::_GLOBAL_RNG, seed::Union{Vector{UInt32}, Vector{UInt64}}) = seed!(default_rng(), seed)
401389
seed!(::_GLOBAL_RNG, n::Integer) = seed!(default_rng(), n)
402390
seed!(::_GLOBAL_RNG, ::Nothing) = seed!(default_rng(), nothing)
403391
seed!(::_GLOBAL_RNG) = seed!(default_rng(), nothing)

stdlib/Random/test/runtests.jl

+18-20
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ guardseed() do
627627
m = MersenneTwister(0)
628628
@test Random.seed!() === g
629629
@test Random.seed!(rand(UInt)) === g
630-
@test Random.seed!(rand(UInt32, rand(1:10))) === g
630+
@test Random.seed!(rand(UInt32, rand(1:8))) === g
631631
@test Random.seed!(m) === m
632632
@test Random.seed!(m, rand(UInt)) === m
633633
@test Random.seed!(m, rand(UInt32, rand(1:10))) === m
@@ -751,28 +751,26 @@ end
751751
@test Random.seed!(GLOBAL_RNG, 0) === LOCAL_RNG
752752
@test Random.seed!(GLOBAL_RNG) === LOCAL_RNG
753753

754-
mt = MersenneTwister(1)
755-
@test copy!(mt, GLOBAL_RNG) === mt
756-
@test mt == LOCAL_RNG
757-
Random.seed!(mt, 2)
758-
@test mt != LOCAL_RNG
759-
@test copy!(GLOBAL_RNG, mt) === LOCAL_RNG
760-
@test mt == LOCAL_RNG
761-
mt2 = copy(GLOBAL_RNG)
762-
@test mt2 isa typeof(LOCAL_RNG)
763-
@test mt2 !== LOCAL_RNG
764-
@test mt2 == LOCAL_RNG
754+
xo = Xoshiro()
755+
@test copy!(xo, GLOBAL_RNG) === xo
756+
@test xo == LOCAL_RNG
757+
Random.seed!(xo, 2)
758+
@test xo != LOCAL_RNG
759+
@test copy!(GLOBAL_RNG, xo) === LOCAL_RNG
760+
@test xo == LOCAL_RNG
761+
xo2 = copy(GLOBAL_RNG)
762+
@test xo2 !== LOCAL_RNG
763+
@test xo2 == LOCAL_RNG
765764

766765
for T in (Random.UInt52Raw{UInt64},
767-
Random.UInt2x52Raw{UInt128},
768766
Random.UInt104Raw{UInt128},
769767
Random.CloseOpen12_64)
770768
x = Random.SamplerTrivial(T())
771-
@test rand(GLOBAL_RNG, x) === rand(mt, x)
769+
@test rand(GLOBAL_RNG, x) === rand(xo, x)
772770
end
773771
for T in (Int64, UInt64, Int128, UInt128, Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)
774772
x = Random.SamplerType{T}()
775-
@test rand(GLOBAL_RNG, x) === rand(mt, x)
773+
@test rand(GLOBAL_RNG, x) === rand(xo, x)
776774
end
777775

778776
A = fill(0.0, 100, 100)
@@ -781,21 +779,21 @@ end
781779
vB = view(B, :, :)
782780
I1 = Random.SamplerTrivial(Random.CloseOpen01{Float64}())
783781
I2 = Random.SamplerTrivial(Random.CloseOpen12{Float64}())
784-
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(mt, B, I1) === B
782+
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(xo, B, I1) === B
785783
B = fill!(B, 1.0)
786784
@test rand!(GLOBAL_RNG, vA, I1) === vA
787-
rand!(mt, vB, I1)
785+
rand!(xo, vB, I1)
788786
@test A == B
789787
for T in (Float16, Float32)
790788
B = fill!(B, 1.0)
791-
@test rand!(GLOBAL_RNG, A, I2) === A == rand!(mt, B, I2) === B
789+
@test rand!(GLOBAL_RNG, A, I2) === A == rand!(xo, B, I2) === B
792790
B = fill!(B, 1.0)
793-
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(mt, B, I1) === B
791+
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(xo, B, I1) === B
794792
end
795793
for T in Base.BitInteger_types
796794
x = Random.SamplerType{T}()
797795
B = fill!(B, 1.0)
798-
@test rand!(GLOBAL_RNG, A, x) === A == rand!(mt, B, x) === B
796+
@test rand!(GLOBAL_RNG, A, x) === A == rand!(xo, B, x) === B
799797
end
800798
# issue #33170
801799
@test Sampler(GLOBAL_RNG, 2:4, Val(1)) isa SamplerRangeNDL

stdlib/Test/src/Test.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -1227,8 +1227,6 @@ function testset_beginend(args, tests, source)
12271227
local RNG = default_rng()
12281228
local oldrng = copy(RNG)
12291229
try
1230-
# RNG is re-seeded with its own seed to ease reproduce a failed test
1231-
Random.seed!(RNG.seed)
12321230
let
12331231
$(esc(tests))
12341232
end
@@ -1319,7 +1317,6 @@ function testset_forloop(args, testloop, source)
13191317
local ts
13201318
local RNG = default_rng()
13211319
local oldrng = copy(RNG)
1322-
Random.seed!(RNG.seed)
13231320
local tmprng = copy(RNG)
13241321
try
13251322
let
@@ -1790,7 +1787,7 @@ end
17901787

17911788
"`guardseed(f, seed)` is equivalent to running `Random.seed!(seed); f()` and
17921789
then restoring the state of the global RNG as it was before."
1793-
guardseed(f::Function, seed::Union{Vector{UInt32},Integer}) = guardseed() do
1790+
guardseed(f::Function, seed::Union{Vector{UInt64},Vector{UInt32},Integer}) = guardseed() do
17941791
Random.seed!(seed)
17951792
f()
17961793
end

0 commit comments

Comments
 (0)