Skip to content

Commit 004b12c

Browse files
committed
random: introduce State to formalize hooking into rand machinery
1 parent f2fd1f8 commit 004b12c

File tree

5 files changed

+320
-308
lines changed

5 files changed

+320
-308
lines changed

base/random/RNGs.jl

+94-78
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
## RandomDevice
44

5-
const BoolBitIntegerType = Union{Type{Bool},Base.BitIntegerType}
6-
const BoolBitIntegerArray = Union{Array{Bool},Base.BitIntegerArray}
5+
6+
StateTypes(U::Union) = Union{map(T->StateType{T}, Base.uniontypes(U))...}
7+
const StateBoolBitInteger = StateTypes(Union{Bool, Base.BitInteger})
78

89
if Sys.iswindows()
910
struct RandomDevice <: AbstractRNG
@@ -12,15 +13,9 @@ if Sys.iswindows()
1213
RandomDevice() = new(Vector{UInt128}(1))
1314
end
1415

15-
function rand(rd::RandomDevice, T::BoolBitIntegerType)
16+
function rand(rd::RandomDevice, st::StateBoolBitInteger)
1617
rand!(rd, rd.buffer)
17-
@inbounds return rd.buffer[1] % T
18-
end
19-
20-
function rand!(rd::RandomDevice, A::BoolBitIntegerArray)
21-
ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32),
22-
A, sizeof(A))
23-
A
18+
@inbounds return rd.buffer[1] % st[]
2419
end
2520
else # !windows
2621
struct RandomDevice <: AbstractRNG
@@ -31,10 +26,22 @@ else # !windows
3126
new(open(unlimited ? "/dev/urandom" : "/dev/random"), unlimited)
3227
end
3328

34-
rand(rd::RandomDevice, T::BoolBitIntegerType) = read( rd.file, T)
35-
rand!(rd::RandomDevice, A::BoolBitIntegerArray) = read!(rd.file, A)
29+
rand(rd::RandomDevice, st::StateBoolBitInteger) = read( rd.file, st[])
3630
end # os-test
3731

32+
# NOTE: this can't be put in within the if-else block above
33+
for T in (Bool, Base.BitInteger_types...)
34+
if Sys.iswindows()
35+
@eval function rand!(rd::RandomDevice, A::Array{$T}, ::StateType{$T})
36+
ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32),
37+
A, sizeof(A))
38+
A
39+
end
40+
else
41+
@eval rand!(rd::RandomDevice, A::Array{$T}, ::StateType{$T}) = read!(rd.file, A)
42+
end
43+
end
44+
3845
"""
3946
RandomDevice()
4047
@@ -49,7 +56,7 @@ srand(rng::RandomDevice) = rng
4956

5057
### generation of floats
5158

52-
rand(r::RandomDevice, I::FloatInterval) = rand_generic(r, I)
59+
rand(r::RandomDevice, st::StateTrivial{<:FloatInterval}) = rand_generic(r, st[])
5360

5461

5562
## MersenneTwister
@@ -229,30 +236,30 @@ rand_ui23_raw(r::MersenneTwister) = rand_ui52_raw(r)
229236

230237
#### floats
231238

232-
rand(r::MersenneTwister, I::FloatInterval_64) = (reserve_1(r); rand_inbounds(r, I))
239+
rand(r::MersenneTwister, st::StateTrivial{<:FloatInterval_64}) = (reserve_1(r); rand_inbounds(r, st[]))
233240

234-
rand(r::MersenneTwister, I::FloatInterval) = rand_generic(r, I)
241+
rand(r::MersenneTwister, st::StateTrivial{<:FloatInterval}) = rand_generic(r, st[])
235242

236243
#### integers
237244

238245
rand(r::MersenneTwister,
239-
::Type{T}) where {T<:Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32}} =
240-
rand_ui52_raw(r) % T
246+
T::StateTypes(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
247+
rand_ui52_raw(r) % T[]
241248

242-
function rand(r::MersenneTwister, ::Type{UInt64})
249+
function rand(r::MersenneTwister, ::StateType{UInt64})
243250
reserve(r, 2)
244251
rand_ui52_raw_inbounds(r) << 32 rand_ui52_raw_inbounds(r)
245252
end
246253

247-
function rand(r::MersenneTwister, ::Type{UInt128})
254+
function rand(r::MersenneTwister, ::StateType{UInt128})
248255
reserve(r, 3)
249256
xor(rand_ui52_raw_inbounds(r) % UInt128 << 96,
250257
rand_ui52_raw_inbounds(r) % UInt128 << 48,
251258
rand_ui52_raw_inbounds(r))
252259
end
253260

254-
rand(r::MersenneTwister, ::Type{Int64}) = reinterpret(Int64, rand(r, UInt64))
255-
rand(r::MersenneTwister, ::Type{Int128}) = reinterpret(Int128, rand(r, UInt128))
261+
rand(r::MersenneTwister, ::StateType{Int64}) = reinterpret(Int64, rand(r, UInt64))
262+
rand(r::MersenneTwister, ::StateType{Int128}) = reinterpret(Int128, rand(r, UInt128))
256263

257264
#### arrays of floats
258265

@@ -278,16 +285,17 @@ function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float6
278285
A
279286
end
280287

281-
rand!(r::MersenneTwister, A::AbstractArray{Float64}) = rand_AbstractArray_Float64!(r, A)
288+
rand!(r::MersenneTwister, A::AbstractArray{Float64}, I::StateTrivial{<:FloatInterval_64}) =
289+
rand_AbstractArray_Float64!(r, A, length(A), I[])
282290

283291
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) =
284292
dsfmt_fill_array_close_open!(s, A, n)
285293

286294
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::Close1Open2_64) =
287295
dsfmt_fill_array_close1_open2!(s, A, n)
288296

289-
function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A),
290-
I::FloatInterval_64=CloseOpen())
297+
function _rand!(r::MersenneTwister, A::Array{Float64}, n::Int,
298+
I::FloatInterval_64)
291299
# depending on the alignment of A, the data written by fill_array! may have
292300
# to be left-shifted by up to 15 bytes (cf. unsafe_copy! below) for
293301
# reproducibility purposes;
@@ -317,65 +325,63 @@ function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A),
317325
A
318326
end
319327

328+
rand!(r::MersenneTwister, A::Array{Float64}, st::StateTrivial{<:FloatInterval_64}) =
329+
_rand!(r, A, length(A), st[])
330+
320331
mask128(u::UInt128, ::Type{Float16}) =
321332
(u & 0x03ff03ff03ff03ff03ff03ff03ff03ff) | 0x3c003c003c003c003c003c003c003c00
322333

323334
mask128(u::UInt128, ::Type{Float32}) =
324335
(u & 0x007fffff007fffff007fffff007fffff) | 0x3f8000003f8000003f8000003f800000
325336

326-
function rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}},
327-
::Close1Open2_64)
328-
T = eltype(A)
329-
n = length(A)
330-
n128 = n * sizeof(T) ÷ 16
331-
Base.@gc_preserve A rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
332-
2*n128, Close1Open2())
333-
# FIXME: This code is completely invalid!!!
334-
A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)
335-
@inbounds for i in 1:n128
336-
u = A128[i]
337-
u ⊻= u << 26
338-
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
339-
# the bit xor, are:
340-
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
341-
# the bits needing to be random are
342-
# [1:10, 17:26, 33:42, 49:58] (for Float16)
343-
# [1:23, 33:55] (for Float32)
344-
# this is obviously satisfied on the 32 low bits side, and on the high side,
345-
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
346-
# (which are discarded on the low side)
347-
# this is similar for the 64 high bits of u
348-
A128[i] = mask128(u, T)
349-
end
350-
for i in 16*n128÷sizeof(T)+1:n
351-
@inbounds A[i] = rand(r, T) + oneunit(T)
337+
for T in (Float16, Float32)
338+
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateTrivial{Close1Open2{$T}})
339+
n = length(A)
340+
n128 = n * sizeof($T) ÷ 16
341+
Base.@gc_preserve A _rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
342+
2*n128, Close1Open2())
343+
# FIXME: This code is completely invalid!!!
344+
A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)
345+
@inbounds for i in 1:n128
346+
u = A128[i]
347+
u ⊻= u << 26
348+
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
349+
# the bit xor, are:
350+
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
351+
# the bits needing to be random are
352+
# [1:10, 17:26, 33:42, 49:58] (for Float16)
353+
# [1:23, 33:55] (for Float32)
354+
# this is obviously satisfied on the 32 low bits side, and on the high side,
355+
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
356+
# (which are discarded on the low side)
357+
# this is similar for the 64 high bits of u
358+
A128[i] = mask128(u, $T)
359+
end
360+
for i in 16*n128÷sizeof($T)+1:n
361+
@inbounds A[i] = rand(r, $T) + oneunit($T)
362+
end
363+
A
352364
end
353-
A
354-
end
355365

356-
function rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}, ::CloseOpen_64)
357-
rand!(r, A, Close1Open2())
358-
I32 = one(Float32)
359-
for i in eachindex(A)
360-
@inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16
366+
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateTrivial{CloseOpen{$T}})
367+
rand!(r, A, Close1Open2($T))
368+
I32 = one(Float32)
369+
for i in eachindex(A)
370+
@inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16
371+
end
372+
A
361373
end
362-
A
363374
end
364375

365-
rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}) =
366-
rand!(r, A, CloseOpen())
367-
368376
#### arrays of integers
369377

370-
function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A))
371-
if n > length(A)
372-
throw(BoundsError(A,n))
373-
end
378+
function rand!(r::MersenneTwister, A::Array{UInt128}, ::StateType{UInt128})
379+
n::Int=length(A)
374380
# FIXME: This code is completely invalid!!!
375381
Af = unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2n)
376382
i = n
377383
while true
378-
rand!(r, Af, 2i, Close1Open2())
384+
_rand!(r, Af, 2i, Close1Open2())
379385
n < 5 && break
380386
i = 0
381387
@inbounds while n-i >= 5
@@ -396,17 +402,18 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A))
396402
A
397403
end
398404

399-
# A::Array{UInt128} will match the specialized method above
400-
function rand!(r::MersenneTwister, A::Base.BitIntegerArray)
401-
n = length(A)
402-
T = eltype(A)
403-
n128 = n * sizeof(T) ÷ 16
404-
# FIXME: This code is completely invalid!!!
405-
rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128))
406-
for i = 16*n128÷sizeof(T)+1:n
407-
@inbounds A[i] = rand(r, T)
405+
for T in Base.BitInteger_types
406+
T === UInt128 && continue
407+
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateType{$T})
408+
n = length(A)
409+
n128 = n * sizeof($T) ÷ 16
410+
# FIXME: This code is completely invalid!!!
411+
rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128))
412+
for i = 16*n128÷sizeof($T)+1:n
413+
@inbounds A[i] = rand(r, $T)
414+
end
415+
A
408416
end
409-
A
410417
end
411418

412419
#### from a range
@@ -418,7 +425,9 @@ function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer
418425
end
419426
end
420427

421-
function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInteger64,Bool}
428+
function rand(rng::MersenneTwister,
429+
st::StateTrivial{UnitRange{T}}) where T<:Union{Base.BitInteger64,Bool}
430+
r = st[]
422431
isempty(r) && throw(ArgumentError("range must be non-empty"))
423432
m = last(r) % UInt64 - first(r) % UInt64
424433
bw = (64 - leading_zeros(m)) % UInt # bit-width
@@ -428,7 +437,9 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInte
428437
(x + first(r) % UInt64) % T
429438
end
430439

431-
function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt128}
440+
function rand(rng::MersenneTwister,
441+
st::StateTrivial{UnitRange{T}}) where T<:Union{Int128,UInt128}
442+
r = st[]
432443
isempty(r) && throw(ArgumentError("range must be non-empty"))
433444
m = (last(r)-first(r)) % UInt128
434445
bw = (128 - leading_zeros(m)) % UInt # bit-width
@@ -439,6 +450,11 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt1
439450
x % T + first(r)
440451
end
441452

453+
for T in (Bool, Base.BitInteger_types...) # eval because of ambiguity otherwise
454+
@eval State(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
455+
StateTrivial(r)
456+
end
457+
442458

443459
### randjump
444460

0 commit comments

Comments
 (0)