Skip to content

Commit 3b9db7c

Browse files
committed
random: introduce Sampler to formalize hooking into rand machinery
[ci skip] [av skip] CI already ran.
1 parent 655e124 commit 3b9db7c

File tree

5 files changed

+341
-308
lines changed

5 files changed

+341
-308
lines changed

base/random/RNGs.jl

+96-79
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
## RandomDevice
44

5-
const BoolBitIntegerType = Union{Type{Bool},Base.BitIntegerType}
6-
const BoolBitIntegerArray = Union{Array{Bool},Base.BitIntegerArray}
5+
# SamplerUnion(Union{X,Y,...}) == Union{SamplerType{X},SamplerType{Y},...}
6+
SamplerUnion(U::Union) = Union{map(T->SamplerType{T}, Base.uniontypes(U))...}
7+
const SamplerBoolBitInteger = SamplerUnion(Union{Bool, Base.BitInteger})
78

89
if Sys.iswindows()
910
struct RandomDevice <: AbstractRNG
@@ -12,15 +13,9 @@ if Sys.iswindows()
1213
RandomDevice() = new(Vector{UInt128}(uninitialized, 1))
1314
end
1415

15-
function rand(rd::RandomDevice, T::BoolBitIntegerType)
16+
function rand(rd::RandomDevice, sp::SamplerBoolBitInteger)
1617
rand!(rd, rd.buffer)
17-
@inbounds return rd.buffer[1] % T
18-
end
19-
20-
function rand!(rd::RandomDevice, A::BoolBitIntegerArray)
21-
ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32),
22-
A, sizeof(A))
23-
A
18+
@inbounds return rd.buffer[1] % sp[]
2419
end
2520
else # !windows
2621
struct RandomDevice <: AbstractRNG
@@ -31,10 +26,22 @@ else # !windows
3126
new(open(unlimited ? "/dev/urandom" : "/dev/random"), unlimited)
3227
end
3328

34-
rand(rd::RandomDevice, T::BoolBitIntegerType) = read( rd.file, T)
35-
rand!(rd::RandomDevice, A::BoolBitIntegerArray) = read!(rd.file, A)
29+
rand(rd::RandomDevice, sp::SamplerBoolBitInteger) = read( rd.file, sp[])
3630
end # os-test
3731

32+
# NOTE: this can't be put within the if-else block above
33+
for T in (Bool, Base.BitInteger_types...)
34+
if Sys.iswindows()
35+
@eval function rand!(rd::RandomDevice, A::Array{$T}, ::SamplerType{$T})
36+
ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32),
37+
A, sizeof(A))
38+
A
39+
end
40+
else
41+
@eval rand!(rd::RandomDevice, A::Array{$T}, ::SamplerType{$T}) = read!(rd.file, A)
42+
end
43+
end
44+
3845
"""
3946
RandomDevice()
4047
@@ -49,7 +56,7 @@ srand(rng::RandomDevice) = rng
4956

5057
### generation of floats
5158

52-
rand(r::RandomDevice, I::FloatInterval) = rand_generic(r, I)
59+
rand(r::RandomDevice, sp::SamplerTrivial{<:FloatInterval}) = rand_generic(r, sp[])
5360

5461

5562
## MersenneTwister
@@ -229,30 +236,31 @@ rand_ui23_raw(r::MersenneTwister) = rand_ui52_raw(r)
229236

230237
#### floats
231238

232-
rand(r::MersenneTwister, I::FloatInterval_64) = (reserve_1(r); rand_inbounds(r, I))
239+
rand(r::MersenneTwister, sp::SamplerTrivial{<:FloatInterval_64}) =
240+
(reserve_1(r); rand_inbounds(r, sp[]))
233241

234-
rand(r::MersenneTwister, I::FloatInterval) = rand_generic(r, I)
242+
rand(r::MersenneTwister, sp::SamplerTrivial{<:FloatInterval}) = rand_generic(r, sp[])
235243

236244
#### integers
237245

238-
rand(r::MersenneTwister, T::Union{Type{Bool}, Type{Int8}, Type{UInt8}, Type{Int16}, Type{UInt16},
239-
Type{Int32}, Type{UInt32}}) =
240-
rand_ui52_raw(r) % T
246+
rand(r::MersenneTwister,
247+
T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
248+
rand_ui52_raw(r) % T[]
241249

242-
function rand(r::MersenneTwister, ::Type{UInt64})
250+
function rand(r::MersenneTwister, ::SamplerType{UInt64})
243251
reserve(r, 2)
244252
rand_ui52_raw_inbounds(r) << 32 rand_ui52_raw_inbounds(r)
245253
end
246254

247-
function rand(r::MersenneTwister, ::Type{UInt128})
255+
function rand(r::MersenneTwister, ::SamplerType{UInt128})
248256
reserve(r, 3)
249257
xor(rand_ui52_raw_inbounds(r) % UInt128 << 96,
250258
rand_ui52_raw_inbounds(r) % UInt128 << 48,
251259
rand_ui52_raw_inbounds(r))
252260
end
253261

254-
rand(r::MersenneTwister, ::Type{Int64}) = reinterpret(Int64, rand(r, UInt64))
255-
rand(r::MersenneTwister, ::Type{Int128}) = reinterpret(Int128, rand(r, UInt128))
262+
rand(r::MersenneTwister, ::SamplerType{Int64}) = reinterpret(Int64, rand(r, UInt64))
263+
rand(r::MersenneTwister, ::SamplerType{Int128}) = reinterpret(Int128, rand(r, UInt128))
256264

257265
#### arrays of floats
258266

@@ -278,16 +286,17 @@ function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float6
278286
A
279287
end
280288

281-
rand!(r::MersenneTwister, A::AbstractArray{Float64}) = rand_AbstractArray_Float64!(r, A)
289+
rand!(r::MersenneTwister, A::AbstractArray{Float64}, I::SamplerTrivial{<:FloatInterval_64}) =
290+
rand_AbstractArray_Float64!(r, A, length(A), I[])
282291

283292
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) =
284293
dsfmt_fill_array_close_open!(s, A, n)
285294

286295
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::Close1Open2_64) =
287296
dsfmt_fill_array_close1_open2!(s, A, n)
288297

289-
function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A),
290-
I::FloatInterval_64=CloseOpen())
298+
function _rand!(r::MersenneTwister, A::Array{Float64}, n::Int,
299+
I::FloatInterval_64)
291300
# depending on the alignment of A, the data written by fill_array! may have
292301
# to be left-shifted by up to 15 bytes (cf. unsafe_copy! below) for
293302
# reproducibility purposes;
@@ -317,65 +326,63 @@ function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A),
317326
A
318327
end
319328

329+
rand!(r::MersenneTwister, A::Array{Float64}, sp::SamplerTrivial{<:FloatInterval_64}) =
330+
_rand!(r, A, length(A), sp[])
331+
320332
mask128(u::UInt128, ::Type{Float16}) =
321333
(u & 0x03ff03ff03ff03ff03ff03ff03ff03ff) | 0x3c003c003c003c003c003c003c003c00
322334

323335
mask128(u::UInt128, ::Type{Float32}) =
324336
(u & 0x007fffff007fffff007fffff007fffff) | 0x3f8000003f8000003f8000003f800000
325337

326-
function rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}},
327-
::Close1Open2_64)
328-
T = eltype(A)
329-
n = length(A)
330-
n128 = n * sizeof(T) ÷ 16
331-
Base.@gc_preserve A rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
332-
2*n128, Close1Open2())
333-
# FIXME: This code is completely invalid!!!
334-
A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)
335-
@inbounds for i in 1:n128
336-
u = A128[i]
337-
u ⊻= u << 26
338-
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
339-
# the bit xor, are:
340-
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
341-
# the bits needing to be random are
342-
# [1:10, 17:26, 33:42, 49:58] (for Float16)
343-
# [1:23, 33:55] (for Float32)
344-
# this is obviously satisfied on the 32 low bits side, and on the high side,
345-
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
346-
# (which are discarded on the low side)
347-
# this is similar for the 64 high bits of u
348-
A128[i] = mask128(u, T)
349-
end
350-
for i in 16*n128÷sizeof(T)+1:n
351-
@inbounds A[i] = rand(r, T) + oneunit(T)
338+
for T in (Float16, Float32)
339+
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerTrivial{Close1Open2{$T}})
340+
n = length(A)
341+
n128 = n * sizeof($T) ÷ 16
342+
Base.@gc_preserve A _rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
343+
2*n128, Close1Open2())
344+
# FIXME: This code is completely invalid!!!
345+
A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)
346+
@inbounds for i in 1:n128
347+
u = A128[i]
348+
u ⊻= u << 26
349+
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
350+
# the bit xor, are:
351+
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
352+
# the bits needing to be random are
353+
# [1:10, 17:26, 33:42, 49:58] (for Float16)
354+
# [1:23, 33:55] (for Float32)
355+
# this is obviously satisfied on the 32 low bits side, and on the high side,
356+
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
357+
# (which are discarded on the low side)
358+
# this is similar for the 64 high bits of u
359+
A128[i] = mask128(u, $T)
360+
end
361+
for i in 16*n128÷sizeof($T)+1:n
362+
@inbounds A[i] = rand(r, $T) + oneunit($T)
363+
end
364+
A
352365
end
353-
A
354-
end
355366

356-
function rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}, ::CloseOpen_64)
357-
rand!(r, A, Close1Open2())
358-
I32 = one(Float32)
359-
for i in eachindex(A)
360-
@inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16
367+
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerTrivial{CloseOpen{$T}})
368+
rand!(r, A, Close1Open2($T))
369+
I32 = one(Float32)
370+
for i in eachindex(A)
371+
@inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16
372+
end
373+
A
361374
end
362-
A
363375
end
364376

365-
rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}) =
366-
rand!(r, A, CloseOpen())
367-
368377
#### arrays of integers
369378

370-
function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A))
371-
if n > length(A)
372-
throw(BoundsError(A,n))
373-
end
379+
function rand!(r::MersenneTwister, A::Array{UInt128}, ::SamplerType{UInt128})
380+
n::Int=length(A)
374381
# FIXME: This code is completely invalid!!!
375382
Af = unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2n)
376383
i = n
377384
while true
378-
rand!(r, Af, 2i, Close1Open2())
385+
_rand!(r, Af, 2i, Close1Open2())
379386
n < 5 && break
380387
i = 0
381388
@inbounds while n-i >= 5
@@ -396,17 +403,18 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A))
396403
A
397404
end
398405

399-
# A::Array{UInt128} will match the specialized method above
400-
function rand!(r::MersenneTwister, A::Base.BitIntegerArray)
401-
n = length(A)
402-
T = eltype(A)
403-
n128 = n * sizeof(T) ÷ 16
404-
# FIXME: This code is completely invalid!!!
405-
rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128))
406-
for i = 16*n128÷sizeof(T)+1:n
407-
@inbounds A[i] = rand(r, T)
406+
for T in Base.BitInteger_types
407+
T === UInt128 && continue
408+
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerType{$T})
409+
n = length(A)
410+
n128 = n * sizeof($T) ÷ 16
411+
# FIXME: This code is completely invalid!!!
412+
rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128))
413+
for i = 16*n128÷sizeof($T)+1:n
414+
@inbounds A[i] = rand(r, $T)
415+
end
416+
A
408417
end
409-
A
410418
end
411419

412420
#### from a range
@@ -418,7 +426,9 @@ function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer
418426
end
419427
end
420428

421-
function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInteger64,Bool}
429+
function rand(rng::MersenneTwister,
430+
sp::SamplerTrivial{UnitRange{T}}) where T<:Union{Base.BitInteger64,Bool}
431+
r = sp[]
422432
isempty(r) && throw(ArgumentError("range must be non-empty"))
423433
m = last(r) % UInt64 - first(r) % UInt64
424434
bw = (64 - leading_zeros(m)) % UInt # bit-width
@@ -428,7 +438,9 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInte
428438
(x + first(r) % UInt64) % T
429439
end
430440

431-
function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt128}
441+
function rand(rng::MersenneTwister,
442+
sp::SamplerTrivial{UnitRange{T}}) where T<:Union{Int128,UInt128}
443+
r = sp[]
432444
isempty(r) && throw(ArgumentError("range must be non-empty"))
433445
m = (last(r)-first(r)) % UInt128
434446
bw = (128 - leading_zeros(m)) % UInt # bit-width
@@ -439,6 +451,11 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt1
439451
x % T + first(r)
440452
end
441453

454+
for T in (Bool, Base.BitInteger_types...) # eval because of ambiguity otherwise
455+
@eval Sampler(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
456+
SamplerTrivial(r)
457+
end
458+
442459

443460
### randjump
444461

0 commit comments

Comments
 (0)