Skip to content

Commit b984cd5

Browse files
committed
Random,threads: allocate state at runtime for each thread
The `Random.GLOBAL_RNG` is now a singleton placeholder object which implements the prior `Random` public API for MersenneTwister as a shim to support existing clients until Julia v2.0.
1 parent 2964312 commit b984cd5

File tree

11 files changed

+191
-82
lines changed

11 files changed

+191
-82
lines changed

stdlib/Random/src/RNGs.jl

+61-9
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
## RandomDevice
44

5-
# SamplerUnion(Union{X,Y,...}) == Union{SamplerType{X},SamplerType{Y},...}
6-
SamplerUnion(U::Union) = Union{map(T->SamplerType{T}, Base.uniontypes(U))...}
7-
const SamplerBoolBitInteger = SamplerUnion(Union{Bool, BitInteger})
5+
# SamplerUnion(X, Y, ...}) == Union{SamplerType{X}, SamplerType{Y}, ...}
6+
SamplerUnion(U...) = Union{Any[SamplerType{T} for T in U]...}
7+
const SamplerBoolBitInteger = SamplerUnion(Bool, BitInteger_types...)
88

99
if Sys.iswindows()
1010
struct RandomDevice <: AbstractRNG
@@ -285,14 +285,66 @@ function seed!(r::MersenneTwister, seed::Vector{UInt32})
285285
return r
286286
end
287287

288-
seed!(r::MersenneTwister=GLOBAL_RNG) = seed!(r, make_seed())
288+
seed!(r::MersenneTwister=get_local_rng()) = seed!(r, make_seed())
289289
seed!(r::MersenneTwister, n::Integer) = seed!(r, make_seed(n))
290-
seed!(seed::Union{Integer,Vector{UInt32}}) = seed!(GLOBAL_RNG, seed)
290+
seed!(seed::Union{Integer,Vector{UInt32}}) = seed!(get_local_rng(), seed)
291291

292292

293-
### Global RNG (must be defined after seed!)
293+
### Global RNG
294294

295-
const GLOBAL_RNG = MersenneTwister(0)
295+
const THREAD_RNGs = MersenneTwister[]
296+
@inline get_local_rng() = get_local_rng(Threads.threadid())
297+
@noinline Base.@pure function get_local_rng(tid::Int)
298+
#tls = task_local_storage()
299+
#RNG = get(tls, :RNG, nothing)
300+
#RNG isa MersenneTwister && return RNG
301+
if length(THREAD_RNGs) < tid
302+
resize!(THREAD_RNGs, Threads.nthreads())
303+
end
304+
if @inbounds isassigned(THREAD_RNGs, tid)
305+
@inbounds MT = THREAD_RNGs[tid]
306+
else
307+
MT = MersenneTwister()
308+
@inbounds THREAD_RNGs[tid] = MT
309+
end
310+
return MT
311+
end
312+
__init__() = empty!(THREAD_RNGs) # ensures that we didn't save a bad object
313+
314+
315+
struct _GLOBAL_RNG <: AbstractRNG
316+
global const GLOBAL_RNG = _GLOBAL_RNG.instance
317+
end
318+
319+
copy!(dst::MersenneTwister, ::_GLOBAL_RNG) = copy!(dst, get_local_rng())
320+
copy!(::_GLOBAL_RNG, src::MersenneTwister) = copy!(get_local_rng(), src)
321+
copy(::_GLOBAL_RNG) = copy(get_local_rng())
322+
323+
seed!(::_GLOBAL_RNG, seed::Vector{UInt32}) = seed!(get_local_rng(), seed)
324+
seed!(::_GLOBAL_RNG, n::Integer) = seed!(get_local_rng(), n)
325+
seed!(::_GLOBAL_RNG, ::Nothing) = seed!(get_local_rng(), nothing)
326+
327+
rng_native_52(::_GLOBAL_RNG) = rng_native_52(get_local_rng())
328+
rand(::_GLOBAL_RNG, sp::SamplerBoolBitInteger) = rand(get_local_rng(), sp)
329+
for T in (:(SamplerTrivial{UInt52Raw{UInt64}}),
330+
:(SamplerTrivial{UInt2x52Raw{UInt128}}),
331+
:(SamplerTrivial{UInt104Raw{UInt128}}),
332+
:(SamplerTrivial{CloseOpen12_64}),
333+
:(SamplerUnion(Int64, UInt64, Int128, UInt128)),
334+
:(SamplerUnion(Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)),
335+
)
336+
@eval rand(::_GLOBAL_RNG, x::$T) = rand(get_local_rng(), x)
337+
end
338+
339+
rand!(::_GLOBAL_RNG, A::AbstractArray{Float64}, I::SamplerTrivial{<:FloatInterval_64}) = rand!(get_local_rng(), A, I)
340+
rand!(::_GLOBAL_RNG, A::Array{Float64}, I::SamplerTrivial{<:FloatInterval_64}) = rand!(get_local_rng(), A, I)
341+
for T in (Float16, Float32)
342+
@eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerTrivial{CloseOpen12{$T}}) = rand!(get_local_rng(), A, I)
343+
@eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerTrivial{CloseOpen01{$T}}) = rand!(get_local_rng(), A, I)
344+
end
345+
for T in BitInteger_types
346+
@eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerType{$T}) = rand!(get_local_rng(), A, I)
347+
end
296348

297349

298350
### generation
@@ -332,10 +384,10 @@ rand(r::MersenneTwister, sp::SamplerTrivial{CloseOpen12_64}) =
332384

333385
#### integers
334386

335-
rand(r::MersenneTwister, T::SamplerUnion(Union{Int64,UInt64,Int128,UInt128})) =
387+
rand(r::MersenneTwister, T::SamplerUnion(Int64, UInt64, Int128, UInt128)) =
336388
mt_pop!(r, T[])
337389

338-
rand(r::MersenneTwister, T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
390+
rand(r::MersenneTwister, T::SamplerUnion(Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)) =
339391
rand(r, UInt52Raw()) % T[]
340392

341393
#### arrays of floats

stdlib/Random/src/Random.jl

+10-21
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ abstract type Sampler{E} end
108108
gentype(::Type{<:Sampler{E}}) where {E} = E
109109

110110
# temporarily for BaseBenchmarks
111-
RangeGenerator(x) = Sampler(GLOBAL_RNG, x)
111+
RangeGenerator(x) = Sampler(get_local_rng(), x)
112112

113113
# In some cases, when only 1 random value is to be generated,
114114
# the optimal sampler can be different than if multiple values
@@ -247,18 +247,18 @@ rand(rng::AbstractRNG, ::UniformT{T}) where {T} = rand(rng, T)
247247

248248
#### scalars
249249

250-
rand(rng::AbstractRNG, X) = rand(rng, Sampler(rng, X, Val(1)))
250+
rand(rng::AbstractRNG, X) = rand(rng, Sampler(rng, X, Val(1)))
251251
# this is needed to disambiguate
252-
rand(rng::AbstractRNG, X::Dims) = rand(rng, Sampler(rng, X, Val(1)))
253-
rand(rng::AbstractRNG=GLOBAL_RNG, ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))
252+
rand(rng::AbstractRNG, X::Dims) = rand(rng, Sampler(rng, X, Val(1)))
253+
rand(rng::AbstractRNG=get_local_rng(), ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))
254254

255-
rand(X) = rand(GLOBAL_RNG, X)
256-
rand(::Type{X}) where {X} = rand(GLOBAL_RNG, X)
255+
rand(X) = rand(get_local_rng(), X)
256+
rand(::Type{X}) where {X} = rand(get_local_rng(), X)
257257

258258
#### arrays
259259

260-
rand!(A::AbstractArray{T}, X) where {T} = rand!(GLOBAL_RNG, A, X)
261-
rand!(A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(GLOBAL_RNG, A, X)
260+
rand!(A::AbstractArray{T}, X) where {T} = rand!(get_local_rng(), A, X)
261+
rand!(A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(get_local_rng(), A, X)
262262

263263
rand!(rng::AbstractRNG, A::AbstractArray{T}, X) where {T} = rand!(rng, A, Sampler(rng, X))
264264
rand!(rng::AbstractRNG, A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(rng, A, Sampler(rng, X))
@@ -274,7 +274,7 @@ rand(r::AbstractRNG, dims::Integer...) = rand(r, Float64, Dims(dims))
274274
rand( dims::Integer...) = rand(Float64, Dims(dims))
275275

276276
rand(r::AbstractRNG, X, dims::Dims) = rand!(r, Array{gentype(X)}(undef, dims), X)
277-
rand( X, dims::Dims) = rand(GLOBAL_RNG, X, dims)
277+
rand( X, dims::Dims) = rand(get_local_rng(), X, dims)
278278

279279
rand(r::AbstractRNG, X, d::Integer, dims::Integer...) = rand(r, X, Dims((d, dims...)))
280280
rand( X, d::Integer, dims::Integer...) = rand(X, Dims((d, dims...)))
@@ -283,23 +283,12 @@ rand( X, d::Integer, dims::Integer...) = rand(X, Dims((d, dims...
283283
# moreover, a call like rand(r, NotImplementedType()) would be an infinite loop
284284

285285
rand(r::AbstractRNG, ::Type{X}, dims::Dims) where {X} = rand!(r, Array{X}(undef, dims), X)
286-
rand( ::Type{X}, dims::Dims) where {X} = rand(GLOBAL_RNG, X, dims)
286+
rand( ::Type{X}, dims::Dims) where {X} = rand(get_local_rng(), X, dims)
287287

288288
rand(r::AbstractRNG, ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(r, X, Dims((d, dims...)))
289289
rand( ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(X, Dims((d, dims...)))
290290

291291

292-
## __init__ & include
293-
294-
function __init__()
295-
try
296-
seed!()
297-
catch ex
298-
Base.showerror_nostdio(ex,
299-
"WARNING: Error during initialization of module Random")
300-
end
301-
end
302-
303292
include("RNGs.jl")
304293
include("generation.jl")
305294
include("normal.jl")

stdlib/Random/src/misc.jl

+10-10
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ let b = UInt8['0':'9';'A':'Z';'a':'z']
7373
global randstring
7474
randstring(r::AbstractRNG, chars=b, n::Integer=8) = String(rand(r, chars, n))
7575
randstring(r::AbstractRNG, n::Integer) = randstring(r, b, n)
76-
randstring(chars=b, n::Integer=8) = randstring(GLOBAL_RNG, chars, n)
77-
randstring(n::Integer) = randstring(GLOBAL_RNG, b, n)
76+
randstring(chars=b, n::Integer=8) = randstring(get_local_rng(), chars, n)
77+
randstring(n::Integer) = randstring(get_local_rng(), b, n)
7878
end
7979

8080

@@ -140,7 +140,7 @@ julia> S
140140
8
141141
```
142142
"""
143-
randsubseq!(S::AbstractArray, A::AbstractArray, p::Real) = randsubseq!(GLOBAL_RNG, S, A, p)
143+
randsubseq!(S::AbstractArray, A::AbstractArray, p::Real) = randsubseq!(get_local_rng(), S, A, p)
144144

145145
randsubseq(r::AbstractRNG, A::AbstractArray{T}, p::Real) where {T} =
146146
randsubseq!(r, T[], A, p)
@@ -163,7 +163,7 @@ julia> randsubseq(rng, collect(1:8), 0.3)
163163
8
164164
```
165165
"""
166-
randsubseq(A::AbstractArray, p::Real) = randsubseq(GLOBAL_RNG, A, p)
166+
randsubseq(A::AbstractArray, p::Real) = randsubseq(get_local_rng(), A, p)
167167

168168

169169
## rand Less Than Masked 52 bits (helper function)
@@ -217,7 +217,7 @@ function shuffle!(r::AbstractRNG, a::AbstractArray)
217217
return a
218218
end
219219

220-
shuffle!(a::AbstractArray) = shuffle!(GLOBAL_RNG, a)
220+
shuffle!(a::AbstractArray) = shuffle!(get_local_rng(), a)
221221

222222
"""
223223
shuffle([rng=GLOBAL_RNG,] v::AbstractArray)
@@ -246,7 +246,7 @@ julia> shuffle(rng, Vector(1:10))
246246
```
247247
"""
248248
shuffle(r::AbstractRNG, a::AbstractArray) = shuffle!(r, copymutable(a))
249-
shuffle(a::AbstractArray) = shuffle(GLOBAL_RNG, a)
249+
shuffle(a::AbstractArray) = shuffle(get_local_rng(), a)
250250

251251

252252
## randperm & randperm!
@@ -277,7 +277,7 @@ julia> randperm(MersenneTwister(1234), 4)
277277
```
278278
"""
279279
randperm(r::AbstractRNG, n::T) where {T <: Integer} = randperm!(r, Vector{T}(undef, n))
280-
randperm(n::Integer) = randperm(GLOBAL_RNG, n)
280+
randperm(n::Integer) = randperm(get_local_rng(), n)
281281

282282
"""
283283
randperm!([rng=GLOBAL_RNG,] A::Array{<:Integer})
@@ -314,7 +314,7 @@ function randperm!(r::AbstractRNG, a::Array{<:Integer})
314314
return a
315315
end
316316

317-
randperm!(a::Array{<:Integer}) = randperm!(GLOBAL_RNG, a)
317+
randperm!(a::Array{<:Integer}) = randperm!(get_local_rng(), a)
318318

319319

320320
## randcycle & randcycle!
@@ -343,7 +343,7 @@ julia> randcycle(MersenneTwister(1234), 6)
343343
```
344344
"""
345345
randcycle(r::AbstractRNG, n::T) where {T <: Integer} = randcycle!(r, Vector{T}(undef, n))
346-
randcycle(n::Integer) = randcycle(GLOBAL_RNG, n)
346+
randcycle(n::Integer) = randcycle(get_local_rng(), n)
347347

348348
"""
349349
randcycle!([rng=GLOBAL_RNG,] A::Array{<:Integer})
@@ -379,4 +379,4 @@ function randcycle!(r::AbstractRNG, a::Array{<:Integer})
379379
return a
380380
end
381381

382-
randcycle!(a::Array{<:Integer}) = randcycle!(GLOBAL_RNG, a)
382+
randcycle!(a::Array{<:Integer}) = randcycle!(get_local_rng(), a)

stdlib/Random/src/normal.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ julia> randn(rng, ComplexF32, (2, 3))
3535
0.611224+1.56403im 0.355204-0.365563im 0.0905552+1.31012im
3636
```
3737
"""
38-
@inline function randn(rng::AbstractRNG=GLOBAL_RNG)
38+
@inline function randn(rng::AbstractRNG=get_local_rng())
3939
@inbounds begin
4040
r = rand(rng, UInt52())
4141
rabs = Int64(r>>1) # One bit for the sign
@@ -95,7 +95,7 @@ julia> randexp(rng, 3, 3)
9595
0.695867 0.693292 0.643644
9696
```
9797
"""
98-
function randexp(rng::AbstractRNG=GLOBAL_RNG)
98+
function randexp(rng::AbstractRNG=get_local_rng())
9999
@inbounds begin
100100
ri = rand(rng, UInt52())
101101
idx = ri & 0xFF
@@ -165,7 +165,7 @@ for randfun in [:randn, :randexp]
165165
@eval begin
166166
# scalars
167167
$randfun(rng::AbstractRNG, T::BitFloatType) = convert(T, $randfun(rng))
168-
$randfun(::Type{T}) where {T} = $randfun(GLOBAL_RNG, T)
168+
$randfun(::Type{T}) where {T} = $randfun(get_local_rng(), T)
169169

170170
# filling arrays
171171
function $randfun!(rng::AbstractRNG, A::AbstractArray{T}) where T
@@ -175,19 +175,19 @@ for randfun in [:randn, :randexp]
175175
A
176176
end
177177

178-
$randfun!(A::AbstractArray) = $randfun!(GLOBAL_RNG, A)
178+
$randfun!(A::AbstractArray) = $randfun!(get_local_rng(), A)
179179

180180
# generating arrays
181181
$randfun(rng::AbstractRNG, ::Type{T}, dims::Dims ) where {T} = $randfun!(rng, Array{T}(undef, dims))
182182
# Note that this method explicitly does not define $randfun(rng, T),
183183
# in order to prevent an infinite recursion.
184184
$randfun(rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...) where {T} = $randfun!(rng, Array{T}(undef, dim1, dims...))
185-
$randfun( ::Type{T}, dims::Dims ) where {T} = $randfun(GLOBAL_RNG, T, dims)
186-
$randfun( ::Type{T}, dims::Integer... ) where {T} = $randfun(GLOBAL_RNG, T, dims...)
185+
$randfun( ::Type{T}, dims::Dims ) where {T} = $randfun(get_local_rng(), T, dims)
186+
$randfun( ::Type{T}, dims::Integer... ) where {T} = $randfun(get_local_rng(), T, dims...)
187187
$randfun(rng::AbstractRNG, dims::Dims ) = $randfun(rng, Float64, dims)
188188
$randfun(rng::AbstractRNG, dims::Integer... ) = $randfun(rng, Float64, dims...)
189-
$randfun( dims::Dims ) = $randfun(GLOBAL_RNG, Float64, dims)
190-
$randfun( dims::Integer... ) = $randfun(GLOBAL_RNG, Float64, dims...)
189+
$randfun( dims::Dims ) = $randfun(get_local_rng(), Float64, dims)
190+
$randfun( dims::Integer... ) = $randfun(get_local_rng(), Float64, dims...)
191191
end
192192
end
193193

stdlib/Random/test/runtests.jl

+58-1
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ end
601601

602602
# Random.seed!(rng, ...) returns rng (#21248)
603603
guardseed() do
604-
g = Random.GLOBAL_RNG
604+
g = Random.get_local_rng()
605605
m = MersenneTwister(0)
606606
@test Random.seed!() === g
607607
@test Random.seed!(rand(UInt)) === g
@@ -713,3 +713,60 @@ end
713713
@test rand((x, 2, 3, 4, 6)) 1:6
714714
end
715715
end
716+
717+
@testset "GLOBAL_RNG" begin
718+
local GLOBAL_RNG = Random.GLOBAL_RNG
719+
local LOCAL_RNG = Random.get_local_rng()
720+
@test VERSION < v"2" # deprecate this in v2
721+
722+
@test Random.seed!(GLOBAL_RNG, nothing) === LOCAL_RNG
723+
@test Random.seed!(GLOBAL_RNG, UInt32[0]) === LOCAL_RNG
724+
@test Random.seed!(GLOBAL_RNG, 0) === LOCAL_RNG
725+
726+
mt = MersenneTwister(1)
727+
@test copy!(mt, GLOBAL_RNG) === mt
728+
@test mt == LOCAL_RNG
729+
Random.seed!(mt, 2)
730+
@test mt != LOCAL_RNG
731+
@test copy!(GLOBAL_RNG, mt) === LOCAL_RNG
732+
@test mt == LOCAL_RNG
733+
mt2 = copy(GLOBAL_RNG)
734+
@test mt2 isa typeof(LOCAL_RNG)
735+
@test mt2 !== LOCAL_RNG
736+
@test mt2 == LOCAL_RNG
737+
738+
for T in (Random.UInt52Raw{UInt64},
739+
Random.UInt2x52Raw{UInt128},
740+
Random.UInt104Raw{UInt128},
741+
Random.CloseOpen12_64)
742+
x = Random.SamplerTrivial(T())
743+
@test rand(GLOBAL_RNG, x) === rand(mt, x)
744+
end
745+
for T in (Int64, UInt64, Int128, UInt128, Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)
746+
x = Random.SamplerType{T}()
747+
@test rand(GLOBAL_RNG, x) === rand(mt, x)
748+
end
749+
750+
A = fill(0.0, 100, 100)
751+
B = fill(1.0, 100, 100)
752+
vA = view(A, :, :)
753+
vB = view(B, :, :)
754+
I1 = Random.SamplerTrivial(Random.CloseOpen01{Float64}())
755+
I2 = Random.SamplerTrivial(Random.CloseOpen12{Float64}())
756+
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(mt, B, I1) === B
757+
B = fill!(B, 1.0)
758+
@test rand!(GLOBAL_RNG, vA, I1) === vA
759+
rand!(mt, vB, I1)
760+
@test A == B
761+
for T in (Float16, Float32)
762+
B = fill!(B, 1.0)
763+
@test rand!(GLOBAL_RNG, A, I2) === A == rand!(mt, B, I2) === B
764+
B = fill!(B, 1.0)
765+
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(mt, B, I1) === B
766+
end
767+
for T in Base.BitInteger_types
768+
x = Random.SamplerType{T}()
769+
B = fill!(B, 1.0)
770+
@test rand!(GLOBAL_RNG, A, x) === A == rand!(mt, B, x) === B
771+
end
772+
end

stdlib/SparseArrays/src/SparseArrays.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
2929
rotl90, rotr90, round, setindex!, similar, size, transpose,
3030
vec, permute!, map, map!, Array, diff, circshift!, circshift
3131

32-
using Random: GLOBAL_RNG, AbstractRNG, randsubseq, randsubseq!
32+
using Random: get_local_rng, AbstractRNG, randsubseq, randsubseq!
3333

3434
export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
3535
SparseMatrixCSC, SparseVector, blockdiag, droptol!, dropzeros!, dropzeros,

0 commit comments

Comments
 (0)