|
2 | 2 |
|
3 | 3 | ## RandomDevice
|
4 | 4 |
|
5 |
| -# SamplerUnion(Union{X,Y,...}) == Union{SamplerType{X},SamplerType{Y},...} |
6 |
| -SamplerUnion(U::Union) = Union{map(T->SamplerType{T}, Base.uniontypes(U))...} |
7 |
| -const SamplerBoolBitInteger = SamplerUnion(Union{Bool, BitInteger}) |
| 5 | +# SamplerUnion(X, Y, ...}) == Union{SamplerType{X}, SamplerType{Y}, ...} |
| 6 | +SamplerUnion(U...) = Union{Any[SamplerType{T} for T in U]...} |
| 7 | +const SamplerBoolBitInteger = SamplerUnion(Bool, BitInteger_types...) |
8 | 8 |
|
9 | 9 | if Sys.iswindows()
|
10 | 10 | struct RandomDevice <: AbstractRNG
|
@@ -285,14 +285,66 @@ function seed!(r::MersenneTwister, seed::Vector{UInt32})
|
285 | 285 | return r
|
286 | 286 | end
|
287 | 287 |
|
288 |
| -seed!(r::MersenneTwister=GLOBAL_RNG) = seed!(r, make_seed()) |
| 288 | +seed!(r::MersenneTwister=get_local_rng()) = seed!(r, make_seed()) |
289 | 289 | seed!(r::MersenneTwister, n::Integer) = seed!(r, make_seed(n))
|
290 |
| -seed!(seed::Union{Integer,Vector{UInt32}}) = seed!(GLOBAL_RNG, seed) |
| 290 | +seed!(seed::Union{Integer,Vector{UInt32}}) = seed!(get_local_rng(), seed) |
291 | 291 |
|
292 | 292 |
|
293 |
| -### Global RNG (must be defined after seed!) |
| 293 | +### Global RNG |
294 | 294 |
|
295 |
| -const GLOBAL_RNG = MersenneTwister(0) |
| 295 | +const THREAD_RNGs = MersenneTwister[] |
| 296 | +@inline get_local_rng() = get_local_rng(Threads.threadid()) |
| 297 | +@noinline Base.@pure function get_local_rng(tid::Int) |
| 298 | + #tls = task_local_storage() |
| 299 | + #RNG = get(tls, :RNG, nothing) |
| 300 | + #RNG isa MersenneTwister && return RNG |
| 301 | + if length(THREAD_RNGs) < tid |
| 302 | + resize!(THREAD_RNGs, Threads.nthreads()) |
| 303 | + end |
| 304 | + if @inbounds isassigned(THREAD_RNGs, tid) |
| 305 | + @inbounds MT = THREAD_RNGs[tid] |
| 306 | + else |
| 307 | + MT = MersenneTwister() |
| 308 | + @inbounds THREAD_RNGs[tid] = MT |
| 309 | + end |
| 310 | + return MT |
| 311 | +end |
| 312 | +__init__() = empty!(THREAD_RNGs) # ensures that we didn't save a bad object |
| 313 | + |
| 314 | + |
| 315 | +struct _GLOBAL_RNG <: AbstractRNG |
| 316 | + global const GLOBAL_RNG = _GLOBAL_RNG.instance |
| 317 | +end |
| 318 | + |
| 319 | +copy!(dst::MersenneTwister, ::_GLOBAL_RNG) = copy!(dst, get_local_rng()) |
| 320 | +copy!(::_GLOBAL_RNG, src::MersenneTwister) = copy!(get_local_rng(), src) |
| 321 | +copy(::_GLOBAL_RNG) = copy(get_local_rng()) |
| 322 | + |
| 323 | +seed!(::_GLOBAL_RNG, seed::Vector{UInt32}) = seed!(get_local_rng(), seed) |
| 324 | +seed!(::_GLOBAL_RNG, n::Integer) = seed!(get_local_rng(), n) |
| 325 | +seed!(::_GLOBAL_RNG, ::Nothing) = seed!(get_local_rng(), nothing) |
| 326 | + |
| 327 | +rng_native_52(::_GLOBAL_RNG) = rng_native_52(get_local_rng()) |
| 328 | +rand(::_GLOBAL_RNG, sp::SamplerBoolBitInteger) = rand(get_local_rng(), sp) |
| 329 | +for T in (:(SamplerTrivial{UInt52Raw{UInt64}}), |
| 330 | + :(SamplerTrivial{UInt2x52Raw{UInt128}}), |
| 331 | + :(SamplerTrivial{UInt104Raw{UInt128}}), |
| 332 | + :(SamplerTrivial{CloseOpen12_64}), |
| 333 | + :(SamplerUnion(Int64, UInt64, Int128, UInt128)), |
| 334 | + :(SamplerUnion(Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)), |
| 335 | + ) |
| 336 | + @eval rand(::_GLOBAL_RNG, x::$T) = rand(get_local_rng(), x) |
| 337 | +end |
| 338 | + |
| 339 | +rand!(::_GLOBAL_RNG, A::AbstractArray{Float64}, I::SamplerTrivial{<:FloatInterval_64}) = rand!(get_local_rng(), A, I) |
| 340 | +rand!(::_GLOBAL_RNG, A::Array{Float64}, I::SamplerTrivial{<:FloatInterval_64}) = rand!(get_local_rng(), A, I) |
| 341 | +for T in (Float16, Float32) |
| 342 | + @eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerTrivial{CloseOpen12{$T}}) = rand!(get_local_rng(), A, I) |
| 343 | + @eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerTrivial{CloseOpen01{$T}}) = rand!(get_local_rng(), A, I) |
| 344 | +end |
| 345 | +for T in BitInteger_types |
| 346 | + @eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerType{$T}) = rand!(get_local_rng(), A, I) |
| 347 | +end |
296 | 348 |
|
297 | 349 |
|
298 | 350 | ### generation
|
@@ -332,10 +384,10 @@ rand(r::MersenneTwister, sp::SamplerTrivial{CloseOpen12_64}) =
|
332 | 384 |
|
333 | 385 | #### integers
|
334 | 386 |
|
335 |
| -rand(r::MersenneTwister, T::SamplerUnion(Union{Int64,UInt64,Int128,UInt128})) = |
| 387 | +rand(r::MersenneTwister, T::SamplerUnion(Int64, UInt64, Int128, UInt128)) = |
336 | 388 | mt_pop!(r, T[])
|
337 | 389 |
|
338 |
| -rand(r::MersenneTwister, T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) = |
| 390 | +rand(r::MersenneTwister, T::SamplerUnion(Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)) = |
339 | 391 | rand(r, UInt52Raw()) % T[]
|
340 | 392 |
|
341 | 393 | #### arrays of floats
|
|
0 commit comments