Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MersenneTwister: more efficient integer generation with caching #25058

Merged
merged 1 commit into from
Dec 23, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions base/int.jl
Original file line number Diff line number Diff line change
@@ -7,25 +7,38 @@
# they are also used elsewhere where Int128/UInt128 support is separated out,
# such as in hashing2.jl

const BitSigned64_types = (Int8, Int16, Int32, Int64)
const BitUnsigned64_types = (UInt8, UInt16, UInt32, UInt64)
const BitSigned32_types = (Int8, Int16, Int32)
const BitUnsigned32_types = (UInt8, UInt16, UInt32)
const BitInteger32_types = (BitSigned32_types..., BitUnsigned32_types...)

const BitSigned64_types = (BitSigned32_types..., Int64)
const BitUnsigned64_types = (BitUnsigned32_types..., UInt64)
const BitInteger64_types = (BitSigned64_types..., BitUnsigned64_types...)

const BitSigned_types = (BitSigned64_types..., Int128)
const BitUnsigned_types = (BitUnsigned64_types..., UInt128)
const BitInteger_types = (BitSigned_types..., BitUnsigned_types...)

const BitSignedSmall_types = Int === Int64 ? ( Int8, Int16, Int32) : ( Int8, Int16)
const BitUnsignedSmall_types = Int === Int64 ? (UInt8, UInt16, UInt32) : (UInt8, UInt16)
const BitIntegerSmall_types = (BitSignedSmall_types..., BitUnsignedSmall_types...)

const BitSigned32 = Union{BitSigned32_types...}
const BitUnsigned32 = Union{BitUnsigned32_types...}
const BitInteger32 = Union{BitInteger32_types...}

const BitSigned64 = Union{BitSigned64_types...}
const BitUnsigned64 = Union{BitUnsigned64_types...}
const BitInteger64 = Union{BitInteger64_types...}

const BitSigned = Union{BitSigned_types...}
const BitUnsigned = Union{BitUnsigned_types...}
const BitInteger = Union{BitInteger_types...}

const BitSignedSmall = Union{BitSignedSmall_types...}
const BitUnsignedSmall = Union{BitUnsignedSmall_types...}
const BitIntegerSmall = Union{BitIntegerSmall_types...}

const BitSigned64T = Union{Type{Int8}, Type{Int16}, Type{Int32}, Type{Int64}}
const BitUnsigned64T = Union{Type{UInt8}, Type{UInt16}, Type{UInt32}, Type{UInt64}}

136 changes: 94 additions & 42 deletions base/random/RNGs.jl
Original file line number Diff line number Diff line change
@@ -60,24 +60,33 @@ srand(rng::RandomDevice) = rng

## MersenneTwister

const MTCacheLength = dsfmt_get_min_array_size()
const MT_CACHE_F = dsfmt_get_min_array_size()
const MT_CACHE_I = 501 << 4

mutable struct MersenneTwister <: AbstractRNG
seed::Vector{UInt32}
state::DSFMT_state
vals::Vector{Float64}
idx::Int

function MersenneTwister(seed, state, vals, idx)
length(vals) == MTCacheLength && 0 <= idx <= MTCacheLength ||
throw(DomainError((length(vals), idx),
"`length(vals)` and `idx` must be consistent with $MTCacheLength"))
new(seed, state, vals, idx)
ints::Vector{UInt128}
idxF::Int
idxI::Int

function MersenneTwister(seed, state, vals, ints, idxF, idxI)
length(vals) == MT_CACHE_F && 0 <= idxF <= MT_CACHE_F ||
throw(DomainError((length(vals), idxF),
"`length(vals)` and `idxF` must be consistent with $MT_CACHE_F"))
length(ints) == MT_CACHE_I >> 4 && 0 <= idxI <= MT_CACHE_I ||
throw(DomainError((length(ints), idxI),
"`length(ints)` and `idxI` must be consistent with $MT_CACHE_I"))
new(seed, state, vals, ints, idxF, idxI)
end
end

MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
MersenneTwister(seed, state, zeros(Float64, MTCacheLength), MTCacheLength)
MersenneTwister(seed, state,
Vector{Float64}(uninitialized, MT_CACHE_F),
Vector{UInt128}(uninitialized, MT_CACHE_I >> 4),
MT_CACHE_F, 0)

"""
MersenneTwister(seed)
@@ -120,27 +129,36 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
copyto!(resize!(dst.seed, length(src.seed)), src.seed)
copy!(dst.state, src.state)
copyto!(dst.vals, src.vals)
dst.idx = src.idx
copyto!(dst.ints, src.ints)
dst.idxF = src.idxF
dst.idxI = src.idxI
dst
end

copy(src::MersenneTwister) =
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), src.idx)
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints),
src.idxF, src.idxI)


==(r1::MersenneTwister, r2::MersenneTwister) =
r1.seed == r2.seed && r1.state == r2.state && isequal(r1.vals, r2.vals) &&
r1.idx == r2.idx
r1.seed == r2.seed && r1.state == r2.state &&
isequal(r1.vals, r2.vals) &&
isequal(r1.ints, r2.ints) &&
r1.idxF == r2.idxF && r1.idxI == r2.idxI

hash(r::MersenneTwister, h::UInt) = foldr(hash, h, (r.seed, r.state, r.vals, r.idx))
hash(r::MersenneTwister, h::UInt) =
foldr(hash, h, (r.seed, r.state, r.vals, r.ints, r.idxF, r.idxI))


### low level API

mt_avail(r::MersenneTwister) = MTCacheLength - r.idx
mt_empty(r::MersenneTwister) = r.idx == MTCacheLength
mt_setfull!(r::MersenneTwister) = r.idx = 0
mt_setempty!(r::MersenneTwister) = r.idx = MTCacheLength
mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idx+=1]
#### floats

mt_avail(r::MersenneTwister) = MT_CACHE_F - r.idxF
mt_empty(r::MersenneTwister) = r.idxF == MT_CACHE_F
mt_setfull!(r::MersenneTwister) = r.idxF = 0
mt_setempty!(r::MersenneTwister) = r.idxF = MT_CACHE_F
mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idxF+=1]

function gen_rand(r::MersenneTwister)
@gc_preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals))
@@ -149,9 +167,56 @@ end

reserve_1(r::MersenneTwister) = (mt_empty(r) && gen_rand(r); nothing)
# `reserve` allows one to call `rand_inbounds` n times
# precondition: n <= MTCacheLength
# precondition: n <= MT_CACHE_F
reserve(r::MersenneTwister, n::Int) = (mt_avail(r) < n && gen_rand(r); nothing)

#### ints

logsizeof(::Type{<:Union{Bool,Int8,UInt8}}) = 0
logsizeof(::Type{<:Union{Int16,UInt16}}) = 1
logsizeof(::Type{<:Union{Int32,UInt32}}) = 2
logsizeof(::Type{<:Union{Int64,UInt64}}) = 3
logsizeof(::Type{<:Union{Int128,UInt128}}) = 4

idxmask(::Type{<:Union{Bool,Int8,UInt8}}) = 15
idxmask(::Type{<:Union{Int16,UInt16}}) = 7
idxmask(::Type{<:Union{Int32,UInt32}}) = 3
idxmask(::Type{<:Union{Int64,UInt64}}) = 1
idxmask(::Type{<:Union{Int128,UInt128}}) = 0


mt_avail(r::MersenneTwister, ::Type{T}) where {T<:BitInteger} =
r.idxI >> logsizeof(T)

function mt_setfull!(r::MersenneTwister, ::Type{<:BitInteger})
rand!(r, r.ints)
r.idxI = MT_CACHE_I
end

mt_setempty!(r::MersenneTwister, ::Type{<:BitInteger}) = r.idxI = 0

function reserve1(r::MersenneTwister, ::Type{T}) where T<:BitInteger
r.idxI < sizeof(T) && mt_setfull!(r, T)
nothing
end

function mt_pop!(r::MersenneTwister, ::Type{T}) where T<:BitInteger
reserve1(r, T)
r.idxI -= sizeof(T)
i = r.idxI
@inbounds x128 = r.ints[1 + i >> 4]
i128 = (i >> logsizeof(T)) & idxmask(T) # 0-based "indice" in x128
(x128 >> (i128 * (sizeof(T) << 3))) % T
end

# not necessary, but very slightly more efficient
function mt_pop!(r::MersenneTwister, ::Type{T}) where {T<:Union{Int128,UInt128}}
reserve1(r, T)
@inbounds res = r.ints[r.idxI >> 4]
r.idxI -= 16
res % T
end


### seeding

@@ -193,6 +258,9 @@ function srand(r::MersenneTwister, seed::Vector{UInt32})
copyto!(resize!(r.seed, length(seed)), seed)
dsfmt_init_by_array(r.state, r.seed)
mt_setempty!(r)
fill!(r.vals, 0.0) # not strictly necessary, but why not, makes comparing two MT easier
mt_setempty!(r, UInt128)
fill!(r.ints, 0)
return r
end

@@ -243,24 +311,8 @@ rand(r::MersenneTwister, sp::SamplerTrivial{Close1Open2_64}) =

#### integers

rand(r::MersenneTwister,
T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
rand(r, UInt52Raw()) % T[]

function rand(r::MersenneTwister, ::SamplerType{UInt64})
reserve(r, 2)
rand_inbounds(r, UInt52Raw()) << 32 rand_inbounds(r, UInt52Raw())
end

function rand(r::MersenneTwister, ::SamplerType{UInt128})
reserve(r, 3)
xor(rand_inbounds(r, UInt52Raw(UInt128)) << 96,
rand_inbounds(r, UInt52Raw(UInt128)) << 48,
rand_inbounds(r, UInt52Raw(UInt128)))
end

rand(r::MersenneTwister, ::SamplerType{Int64}) = rand(r, UInt64) % Int64
rand(r::MersenneTwister, ::SamplerType{Int128}) = rand(r, UInt128) % Int128
rand(r::MersenneTwister, T::SamplerUnion(BitInteger)) = mt_pop!(r, T[])
rand(r::MersenneTwister, ::SamplerType{Bool}) = rand(r, UInt8) % Bool

#### arrays of floats

@@ -315,13 +367,13 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter
mt_avail(r) == 0 && gen_rand(r)
# from now on, at most one call to gen_rand(r) will be necessary
m = min(n, mt_avail(r))
@gc_preserve r unsafe_copyto!(A.ptr, pointer(r.vals, r.idx+1), m)
@gc_preserve r unsafe_copyto!(A.ptr, pointer(r.vals, r.idxF+1), m)
if m == n
r.idx += m
r.idxF += m
else # m < n
gen_rand(r)
@gc_preserve r unsafe_copyto!(A.ptr+m*sizeof(Float64), pointer(r.vals), n-m)
r.idx = n-m
r.idxF = n-m
end
if I isa CloseOpen
for i=1:n
@@ -470,7 +522,7 @@ end

#### from a range

for T in (Bool, BitInteger_types...) # eval because of ambiguity otherwise
for T in BitInteger_types # eval because of ambiguity otherwise
@eval Sampler(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
SamplerRangeFast(r)
end
54 changes: 41 additions & 13 deletions base/random/generation.jl
Original file line number Diff line number Diff line change
@@ -12,7 +12,6 @@
# Note that the 1) is automated when the sampler is not intended to carry information,
# i.e. the default fall-backs SamplerType and SamplerTrivial are used.


## from types: rand(::Type, [dims...])

### random floats
@@ -101,6 +100,8 @@ rand(rng::AbstractRNG, sp::SamplerBigFloat{T}) where {T<:FloatInterval{BigFloat}

### random integers

#### UniformBits

rand(r::AbstractRNG, ::SamplerTrivial{UInt10Raw{UInt16}}) = rand(r, UInt16)
rand(r::AbstractRNG, ::SamplerTrivial{UInt23Raw{UInt32}}) = rand(r, UInt32)

@@ -111,7 +112,7 @@ _rand52(r::AbstractRNG, ::Type{Float64}) = reinterpret(UInt64, rand(r, Close1Ope
_rand52(r::AbstractRNG, ::Type{UInt64}) = rand(r, UInt64)

rand(r::AbstractRNG, ::SamplerTrivial{UInt104Raw{UInt128}}) =
rand(r, UInt52Raw(UInt128)) << 52 rand_inbounds(r, UInt52Raw(UInt128))
rand(r, UInt52Raw(UInt128)) << 52 rand(r, UInt52Raw(UInt128))

rand(r::AbstractRNG, ::SamplerTrivial{UInt10{UInt16}}) = rand(r, UInt10Raw()) & 0x03ff
rand(r::AbstractRNG, ::SamplerTrivial{UInt23{UInt32}}) = rand(r, UInt23Raw()) & 0x007fffff
@@ -121,6 +122,32 @@ rand(r::AbstractRNG, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw())
rand(r::AbstractRNG, sp::SamplerTrivial{<:UniformBits{T}}) where {T} =
rand(r, uint_default(sp[])) % T

#### BitInteger

# rand_generic methods are intended to help RNG implementors with common operations
# we don't call them simply `rand` as this can easily contribute to create
# amibuities with user-side methods (forcing the user to resort to @eval)

rand_generic(r::AbstractRNG, T::Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32}) =
rand(r, UInt52Raw()) % T[]

rand_generic(r::AbstractRNG, ::Type{UInt64}) =
rand(r, UInt52Raw()) << 32 rand(r, UInt52Raw())

rand_generic(r::AbstractRNG, ::Type{UInt128}) = _rand128(r, rng_native_52(r))

_rand128(r::AbstractRNG, ::Type{UInt64}) =
((rand(r, UInt64) % UInt128) << 64) rand(r, UInt64)

function _rand128(r::AbstractRNG, ::Type{Float64})
xor(rand(r, UInt52Raw(UInt128)) << 96,
rand(r, UInt52Raw(UInt128)) << 48,
rand(r, UInt52Raw(UInt128)))
end

rand_generic(r::AbstractRNG, ::Type{Int128}) = rand(r, UInt128) % Int128
rand_generic(r::AbstractRNG, ::Type{Int64}) = rand(r, UInt64) % Int64

### random complex numbers

rand(r::AbstractRNG, ::SamplerType{Complex{T}}) where {T<:Real} =
@@ -149,33 +176,34 @@ end

#### helper functions

uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
uint_sup(::Type{<:Base.BitInteger32}) = UInt32
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128

#### Fast

struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
struct SamplerRangeFast{U<:BitUnsigned,T<:BitInteger} <: Sampler
a::T # first element of the range
bw::UInt # bit width
m::U # range length - 1
mask::U # mask generated values before threshold rejection
end

SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
SamplerRangeFast(r::AbstractUnitRange{T}) where T<:BitInteger =
SamplerRangeFast(r, uint_sup(T))

function SamplerRangeFast(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
isempty(r) && throw(ArgumentError("range must be non-empty"))
m = (last(r) - first(r)) % U
m = (last(r) - first(r)) % unsigned(T) % U # % unsigned(T) to not propagate sign bit
bw = (sizeof(U) << 3 - leading_zeros(m)) % UInt # bit-width
mask = (1 % U << bw) - (1 % U)
SamplerRangeFast{U,T}(first(r), bw, m, mask)
end

function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt32,T}) where T
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
x = rand(rng, LessThan(m, Masked(mask, uniform(UInt32))))
# below, we don't use UInt32, to get reproducible values, whether Int is Int64 or Int32
x = rand(rng, LessThan(m, Masked(mask, UInt52Raw(UInt32))))
(x + a % UInt32) % T
end

@@ -215,21 +243,21 @@ maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} =
div(sup, k + (k == 0))*k - one(k)

struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler
struct SamplerRangeInt{T<:Integer,U<:Unsigned} <: Sampler
a::T # first element of the range
bw::Int # bit width
k::U # range length or zero for full range
u::U # rejection threshold
end


SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
SamplerRangeInt(r::AbstractUnitRange{T}) where T<:BitInteger =
SamplerRangeInt(r, uint_sup(T))

function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
isempty(r) && throw(ArgumentError("range must be non-empty"))
a = first(r)
m = (last(r) - first(r)) % U
m = (last(r) - first(r)) % unsigned(T) % U
k = m + one(U)
bw = (sizeof(U) << 3 - leading_zeros(m)) % Int
mult = if U === UInt32
@@ -247,11 +275,11 @@ function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
end

Sampler(::AbstractRNG, r::AbstractUnitRange{T},
::Repetition) where {T<:Union{Bool,BitInteger}} = SamplerRangeInt(r)
::Repetition) where {T<:BitInteger} = SamplerRangeInt(r)

rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:Union{Bool,BitInteger}} =
(unsigned(sp.a) + rem_knuth(rand(rng, LessThan(sp.u, uniform(UInt32))), sp.k)) % T

rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:BitInteger} =
(unsigned(sp.a) + rem_knuth(rand(rng, LessThan(sp.u, UInt52Raw(UInt32))), sp.k)) % T

# this function uses 52 bit entropy for small ranges of length <= 2^52
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt64}) where T<:BitInteger
25 changes: 16 additions & 9 deletions test/random.jl
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
isdefined(Main, :TestHelpers) || @eval Main include(joinpath(dirname(@__FILE__), "TestHelpers.jl"))
using Main.TestHelpers.OAs

using Base.Random: Sampler, SamplerRangeFast, SamplerRangeInt, dSFMT
using Base.Random: Sampler, SamplerRangeFast, SamplerRangeInt, dSFMT, MT_CACHE_F, MT_CACHE_I

@testset "Issue #6573" begin
srand(0)
@@ -38,7 +38,7 @@ let A = zeros(2, 2)
end
let A = zeros(2, 2)
@test_throws ArgumentError rand!(MersenneTwister(0), A, 5)
@test rand(MersenneTwister(0), Int64, 1) == [4439861565447045202]
@test rand(MersenneTwister(0), Int64, 1) == [5986602421100169002]
end
let A = zeros(Int64, 2, 2)
rand!(MersenneTwister(0), A)
@@ -278,11 +278,10 @@ let mt = MersenneTwister(0)
B = Vector{T}(uninitialized, 31)
rand!(mt, A)
rand!(mt, B)
@test A[end] == Any[21,0x7b,17385,0x3086,-1574090021,0xadcb4460,6797283068698303107,0x4e91c9c4d4f5f759,
-3482609696641744459568613291754091152,Float16(0.03125),0.68733835f0][i]

@test B[end] == Any[49,0x65,-3725,0x719d,814246081,0xdf61843a,-3010919637398300844,0x61b367cf8810985d,
-33032345278809823492812856023466859769,Float16(0.95),0.51829386f0][i]
@test A[end] == Any[21, 0x4e, -3158, 0x0ded, 2132370312, 0x5e76d222, 1701112237820550475, 0xde7c8e58fb113739,
-17260403799139981754163727590537874047, Float16(0.90234), 0.0909704f0][i]
@test B[end] == Any[94, 0xb8, 3111, 0xefa4, 411531475, 0xd8089c1d, -7344871485543005232, 0xedb4b5c61c037a43,
-118178167582054157562031602894265066400, Float16(0.91211), 0.2516626f0][i]
end

srand(mt, 0)
@@ -561,10 +560,18 @@ end

# MersenneTwister initialization with invalid values
@test_throws DomainError Base.dSFMT.DSFMT_state(zeros(Int32, rand(0:Base.dSFMT.JN32-1)))

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(),
zeros(Float64, 10), zeros(UInt128, MT_CACHE_I>>4), 0, 0)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(),
zeros(Float64, 10), 0)
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), -1, 0)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(),
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>3), 0, 0)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(),
zeros(Float64, Base.Random.MTCacheLength), -1)
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), 0, -1)

# seed is private to MersenneTwister
let seed = rand(UInt32, 10)