Skip to content

Commit d32e267

Browse files
committed
MersenneTwister: more efficient integer generation with caching
1 parent 7b1c06a commit d32e267

File tree

4 files changed

+165
-66
lines changed

4 files changed

+165
-66
lines changed

base/int.jl

+15-2
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,38 @@
77
# they are also used elsewhere where Int128/UInt128 support is separated out,
88
# such as in hashing2.jl
99

10-
const BitSigned64_types = (Int8, Int16, Int32, Int64)
11-
const BitUnsigned64_types = (UInt8, UInt16, UInt32, UInt64)
10+
const BitSigned32_types = (Int8, Int16, Int32)
11+
const BitUnsigned32_types = (UInt8, UInt16, UInt32)
12+
const BitInteger32_types = (BitSigned32_types..., BitUnsigned32_types...)
13+
14+
const BitSigned64_types = (BitSigned32_types..., Int64)
15+
const BitUnsigned64_types = (BitUnsigned32_types..., UInt64)
1216
const BitInteger64_types = (BitSigned64_types..., BitUnsigned64_types...)
17+
1318
const BitSigned_types = (BitSigned64_types..., Int128)
1419
const BitUnsigned_types = (BitUnsigned64_types..., UInt128)
1520
const BitInteger_types = (BitSigned_types..., BitUnsigned_types...)
21+
1622
const BitSignedSmall_types = Int === Int64 ? ( Int8, Int16, Int32) : ( Int8, Int16)
1723
const BitUnsignedSmall_types = Int === Int64 ? (UInt8, UInt16, UInt32) : (UInt8, UInt16)
1824
const BitIntegerSmall_types = (BitSignedSmall_types..., BitUnsignedSmall_types...)
1925

26+
const BitSigned32 = Union{BitSigned32_types...}
27+
const BitUnsigned32 = Union{BitUnsigned32_types...}
28+
const BitInteger32 = Union{BitInteger32_types...}
29+
2030
const BitSigned64 = Union{BitSigned64_types...}
2131
const BitUnsigned64 = Union{BitUnsigned64_types...}
2232
const BitInteger64 = Union{BitInteger64_types...}
33+
2334
const BitSigned = Union{BitSigned_types...}
2435
const BitUnsigned = Union{BitUnsigned_types...}
2536
const BitInteger = Union{BitInteger_types...}
37+
2638
const BitSignedSmall = Union{BitSignedSmall_types...}
2739
const BitUnsignedSmall = Union{BitUnsignedSmall_types...}
2840
const BitIntegerSmall = Union{BitIntegerSmall_types...}
41+
2942
const BitSigned64T = Union{Type{Int8}, Type{Int16}, Type{Int32}, Type{Int64}}
3043
const BitUnsigned64T = Union{Type{UInt8}, Type{UInt16}, Type{UInt32}, Type{UInt64}}
3144

base/random/RNGs.jl

+93-42
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,33 @@ srand(rng::RandomDevice) = rng
6060

6161
## MersenneTwister
6262

63-
const MTCacheLength = dsfmt_get_min_array_size()
63+
const MT_CACHE_F = dsfmt_get_min_array_size()
64+
const MT_CACHE_I = 501 << 4
6465

6566
mutable struct MersenneTwister <: AbstractRNG
6667
seed::Vector{UInt32}
6768
state::DSFMT_state
6869
vals::Vector{Float64}
69-
idx::Int
70-
71-
function MersenneTwister(seed, state, vals, idx)
72-
length(vals) == MTCacheLength && 0 <= idx <= MTCacheLength ||
73-
throw(DomainError((length(vals), idx),
74-
"`length(vals)` and `idx` must be consistent with $MTCacheLength"))
75-
new(seed, state, vals, idx)
70+
ints::Vector{UInt128}
71+
idxF::Int
72+
idxI::Int
73+
74+
function MersenneTwister(seed, state, vals, ints, idxF, idxI)
75+
length(vals) == MT_CACHE_F && 0 <= idxF <= MT_CACHE_F ||
76+
throw(DomainError((length(vals), idxF),
77+
"`length(vals)` and `idxF` must be consistent with $MT_CACHE_F"))
78+
length(ints) == MT_CACHE_I >> 4 && 0 <= idxI <= MT_CACHE_I ||
79+
throw(DomainError((length(ints), idxI),
80+
"`length(ints)` and `idxI` must be consistent with $MT_CACHE_I"))
81+
new(seed, state, vals, ints, idxF, idxI)
7682
end
7783
end
7884

7985
MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
80-
MersenneTwister(seed, state, zeros(Float64, MTCacheLength), MTCacheLength)
86+
MersenneTwister(seed, state,
87+
Vector{Float64}(uninitialized, MT_CACHE_F),
88+
Vector{UInt128}(uninitialized, MT_CACHE_I >> 4),
89+
MT_CACHE_F, 0)
8190

8291
"""
8392
MersenneTwister(seed)
@@ -120,27 +129,36 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
120129
copyto!(resize!(dst.seed, length(src.seed)), src.seed)
121130
copy!(dst.state, src.state)
122131
copyto!(dst.vals, src.vals)
123-
dst.idx = src.idx
132+
copyto!(dst.ints, src.ints)
133+
dst.idxF = src.idxF
134+
dst.idxI = src.idxI
124135
dst
125136
end
126137

127138
copy(src::MersenneTwister) =
128-
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), src.idx)
139+
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints),
140+
src.idxF, src.idxI)
141+
129142

130143
==(r1::MersenneTwister, r2::MersenneTwister) =
131-
r1.seed == r2.seed && r1.state == r2.state && isequal(r1.vals, r2.vals) &&
132-
r1.idx == r2.idx
144+
r1.seed == r2.seed && r1.state == r2.state &&
145+
isequal(r1.vals, r2.vals) &&
146+
isequal(r1.ints, r2.ints) &&
147+
r1.idxF == r2.idxF && r1.idxI == r2.idxI
133148

134-
hash(r::MersenneTwister, h::UInt) = foldr(hash, h, (r.seed, r.state, r.vals, r.idx))
149+
hash(r::MersenneTwister, h::UInt) =
150+
foldr(hash, h, (r.seed, r.state, r.vals, r.ints, r.idxF, r.idxI))
135151

136152

137153
### low level API
138154

139-
mt_avail(r::MersenneTwister) = MTCacheLength - r.idx
140-
mt_empty(r::MersenneTwister) = r.idx == MTCacheLength
141-
mt_setfull!(r::MersenneTwister) = r.idx = 0
142-
mt_setempty!(r::MersenneTwister) = r.idx = MTCacheLength
143-
mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idx+=1]
155+
#### floats
156+
157+
mt_avail(r::MersenneTwister) = MT_CACHE_F - r.idxF
158+
mt_empty(r::MersenneTwister) = r.idxF == MT_CACHE_F
159+
mt_setfull!(r::MersenneTwister) = r.idxF = 0
160+
mt_setempty!(r::MersenneTwister) = r.idxF = MT_CACHE_F
161+
mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idxF+=1]
144162

145163
function gen_rand(r::MersenneTwister)
146164
@gc_preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals))
@@ -149,9 +167,55 @@ end
149167

150168
reserve_1(r::MersenneTwister) = (mt_empty(r) && gen_rand(r); nothing)
151169
# `reserve` allows one to call `rand_inbounds` n times
152-
# precondition: n <= MTCacheLength
170+
# precondition: n <= MT_CACHE_F
153171
reserve(r::MersenneTwister, n::Int) = (mt_avail(r) < n && gen_rand(r); nothing)
154172

173+
#### ints
174+
175+
logsizeof(::Type{<:Union{Bool,Int8,UInt8}}) = 0
176+
logsizeof(::Type{<:Union{Int16,UInt16}}) = 1
177+
logsizeof(::Type{<:Union{Int32,UInt32}}) = 2
178+
logsizeof(::Type{<:Union{Int64,UInt64}}) = 3
179+
logsizeof(::Type{<:Union{Int128,UInt128}}) = 4
180+
181+
idxmask(::Type{<:Union{Bool,Int8,UInt8}}) = 15
182+
idxmask(::Type{<:Union{Int16,UInt16}}) = 7
183+
idxmask(::Type{<:Union{Int32,UInt32}}) = 3
184+
idxmask(::Type{<:Union{Int64,UInt64}}) = 1
185+
idxmask(::Type{<:Union{Int128,UInt128}}) = 0
186+
187+
188+
mt_avail(r::MersenneTwister, ::Type{T}) where {T<:BitInteger} =
189+
r.idxI >> logsizeof(T)
190+
191+
function mt_setfull!(r::MersenneTwister, ::Type{<:BitInteger})
192+
rand!(r, r.ints)
193+
r.idxI = MT_CACHE_I
194+
end
195+
196+
mt_setempty!(r::MersenneTwister, ::Type{<:BitInteger}) = r.idxI = 0
197+
198+
function reserve1(r::MersenneTwister, ::Type{T}) where T<:BitInteger
199+
r.idxI < sizeof(T) && mt_setfull!(r, T)
200+
nothing
201+
end
202+
203+
function mt_pop!(r::MersenneTwister, ::Type{T}) where T<:BitInteger
204+
reserve1(r, T)
205+
r.idxI -= sizeof(T)
206+
i = r.idxI
207+
@inbounds x128 = r.ints[1 + i >> 4]
208+
i128 = (i >> logsizeof(T)) & idxmask(T) # 0-based "indice" in x128
209+
(x128 >> (i128 * (sizeof(T) << 3))) % T
210+
end
211+
#=
212+
function mt_pop!(r::MersenneTwister, ::Type{T}) where {T<:Union{Int128,UInt128}}
213+
reserve1(r, T)
214+
@inbounds res = r.ints[r.idxI >> 4]
215+
r.idxI -= 16
216+
res
217+
end
218+
=#
155219

156220
### seeding
157221

@@ -193,6 +257,9 @@ function srand(r::MersenneTwister, seed::Vector{UInt32})
193257
copyto!(resize!(r.seed, length(seed)), seed)
194258
dsfmt_init_by_array(r.state, r.seed)
195259
mt_setempty!(r)
260+
fill!(r.vals, 0.0) # not strictly necessary, but why not, makes comparing two MT easier
261+
mt_setempty!(r, UInt128)
262+
fill!(r.ints, 0)
196263
return r
197264
end
198265

@@ -243,24 +310,8 @@ rand(r::MersenneTwister, sp::SamplerTrivial{Close1Open2_64}) =
243310

244311
#### integers
245312

246-
rand(r::MersenneTwister,
247-
T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
248-
rand(r, UInt52Raw()) % T[]
249-
250-
function rand(r::MersenneTwister, ::SamplerType{UInt64})
251-
reserve(r, 2)
252-
rand_inbounds(r, UInt52Raw()) << 32 rand_inbounds(r, UInt52Raw())
253-
end
254-
255-
function rand(r::MersenneTwister, ::SamplerType{UInt128})
256-
reserve(r, 3)
257-
xor(rand_inbounds(r, UInt52Raw(UInt128)) << 96,
258-
rand_inbounds(r, UInt52Raw(UInt128)) << 48,
259-
rand_inbounds(r, UInt52Raw(UInt128)))
260-
end
261-
262-
rand(r::MersenneTwister, ::SamplerType{Int64}) = rand(r, UInt64) % Int64
263-
rand(r::MersenneTwister, ::SamplerType{Int128}) = rand(r, UInt128) % Int128
313+
rand(r::MersenneTwister, T::SamplerUnion(BitInteger)) = mt_pop!(r, T[])
314+
rand(r::MersenneTwister, ::SamplerType{Bool}) = rand(r, UInt8) % Bool
264315

265316
#### arrays of floats
266317

@@ -315,13 +366,13 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter
315366
mt_avail(r) == 0 && gen_rand(r)
316367
# from now on, at most one call to gen_rand(r) will be necessary
317368
m = min(n, mt_avail(r))
318-
@gc_preserve r unsafe_copyto!(A.ptr, pointer(r.vals, r.idx+1), m)
369+
@gc_preserve r unsafe_copyto!(A.ptr, pointer(r.vals, r.idxF+1), m)
319370
if m == n
320-
r.idx += m
371+
r.idxF += m
321372
else # m < n
322373
gen_rand(r)
323374
@gc_preserve r unsafe_copyto!(A.ptr+m*sizeof(Float64), pointer(r.vals), n-m)
324-
r.idx = n-m
375+
r.idxF = n-m
325376
end
326377
if I isa CloseOpen
327378
for i=1:n
@@ -470,7 +521,7 @@ end
470521

471522
#### from a range
472523

473-
for T in (Bool, BitInteger_types...) # eval because of ambiguity otherwise
524+
for T in BitInteger_types # eval because of ambiguity otherwise
474525
@eval Sampler(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
475526
SamplerRangeFast(r)
476527
end

base/random/generation.jl

+41-13
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# Note that the 1) is automated when the sampler is not intended to carry information,
1313
# i.e. the default fall-backs SamplerType and SamplerTrivial are used.
1414

15-
1615
## from types: rand(::Type, [dims...])
1716

1817
### random floats
@@ -101,6 +100,8 @@ rand(rng::AbstractRNG, sp::SamplerBigFloat{T}) where {T<:FloatInterval{BigFloat}
101100

102101
### random integers
103102

103+
#### UniformBits
104+
104105
rand(r::AbstractRNG, ::SamplerTrivial{UInt10Raw{UInt16}}) = rand(r, UInt16)
105106
rand(r::AbstractRNG, ::SamplerTrivial{UInt23Raw{UInt32}}) = rand(r, UInt32)
106107

@@ -111,7 +112,7 @@ _rand52(r::AbstractRNG, ::Type{Float64}) = reinterpret(UInt64, rand(r, Close1Ope
111112
_rand52(r::AbstractRNG, ::Type{UInt64}) = rand(r, UInt64)
112113

113114
rand(r::AbstractRNG, ::SamplerTrivial{UInt104Raw{UInt128}}) =
114-
rand(r, UInt52Raw(UInt128)) << 52 rand_inbounds(r, UInt52Raw(UInt128))
115+
rand(r, UInt52Raw(UInt128)) << 52 rand(r, UInt52Raw(UInt128))
115116

116117
rand(r::AbstractRNG, ::SamplerTrivial{UInt10{UInt16}}) = rand(r, UInt10Raw()) & 0x03ff
117118
rand(r::AbstractRNG, ::SamplerTrivial{UInt23{UInt32}}) = rand(r, UInt23Raw()) & 0x007fffff
@@ -121,6 +122,32 @@ rand(r::AbstractRNG, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw())
121122
rand(r::AbstractRNG, sp::SamplerTrivial{<:UniformBits{T}}) where {T} =
122123
rand(r, uint_default(sp[])) % T
123124

125+
#### BitInteger
126+
127+
# rand_generic methods are intended to help RNG implementors with common operations
128+
# we don't call them simply `rand` as this can easily contribute to create
129+
# amibuities with user-side methods (forcing the user to resort to @eval)
130+
131+
rand_generic(r::AbstractRNG, T::Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32}) =
132+
rand(r, UInt52Raw()) % T[]
133+
134+
rand_generic(r::AbstractRNG, ::Type{UInt64}) =
135+
rand(r, UInt52Raw()) << 32 rand(r, UInt52Raw())
136+
137+
rand_generic(r::AbstractRNG, ::Type{UInt128}) = _rand128(r, rng_native_52(r))
138+
139+
_rand128(r::AbstractRNG, ::Type{UInt64}) =
140+
((rand(r, UInt64) % UInt128) << 64) rand(r, UInt64)
141+
142+
function _rand128(r::AbstractRNG, ::Type{Float64})
143+
xor(rand(r, UInt52Raw(UInt128)) << 96,
144+
rand(r, UInt52Raw(UInt128)) << 48,
145+
rand(r, UInt52Raw(UInt128)))
146+
end
147+
148+
rand_generic(r::AbstractRNG, ::Type{Int128}) = rand(r, UInt128) % Int128
149+
rand_generic(r::AbstractRNG, ::Type{Int64}) = rand(r, UInt64) % Int64
150+
124151
### random complex numbers
125152

126153
rand(r::AbstractRNG, ::SamplerType{Complex{T}}) where {T<:Real} =
@@ -149,33 +176,34 @@ end
149176

150177
#### helper functions
151178

152-
uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
179+
uint_sup(::Type{<:Base.BitInteger32}) = UInt32
153180
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
154181
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
155182

156183
#### Fast
157184

158-
struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
185+
struct SamplerRangeFast{U<:BitUnsigned,T<:BitInteger} <: Sampler
159186
a::T # first element of the range
160187
bw::UInt # bit width
161188
m::U # range length - 1
162189
mask::U # mask generated values before threshold rejection
163190
end
164191

165-
SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
192+
SamplerRangeFast(r::AbstractUnitRange{T}) where T<:BitInteger =
166193
SamplerRangeFast(r, uint_sup(T))
167194

168195
function SamplerRangeFast(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
169196
isempty(r) && throw(ArgumentError("range must be non-empty"))
170-
m = (last(r) - first(r)) % U
197+
m = (last(r) - first(r)) % unsigned(T) % U # % unsigned(T) to not propagate sign bit
171198
bw = (sizeof(U) << 3 - leading_zeros(m)) % UInt # bit-width
172199
mask = (1 % U << bw) - (1 % U)
173200
SamplerRangeFast{U,T}(first(r), bw, m, mask)
174201
end
175202

176203
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt32,T}) where T
177204
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
178-
x = rand(rng, LessThan(m, Masked(mask, uniform(UInt32))))
205+
# below, we don't use UInt32, to get reproducible values, whether Int is Int64 or Int32
206+
x = rand(rng, LessThan(m, Masked(mask, UInt52Raw(UInt32))))
179207
(x + a % UInt32) % T
180208
end
181209

@@ -215,21 +243,21 @@ maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
215243
unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} =
216244
div(sup, k + (k == 0))*k - one(k)
217245

218-
struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler
246+
struct SamplerRangeInt{T<:Integer,U<:Unsigned} <: Sampler
219247
a::T # first element of the range
220248
bw::Int # bit width
221249
k::U # range length or zero for full range
222250
u::U # rejection threshold
223251
end
224252

225253

226-
SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
254+
SamplerRangeInt(r::AbstractUnitRange{T}) where T<:BitInteger =
227255
SamplerRangeInt(r, uint_sup(T))
228256

229257
function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
230258
isempty(r) && throw(ArgumentError("range must be non-empty"))
231259
a = first(r)
232-
m = (last(r) - first(r)) % U
260+
m = (last(r) - first(r)) % unsigned(T) % U
233261
k = m + one(U)
234262
bw = (sizeof(U) << 3 - leading_zeros(m)) % Int
235263
mult = if U === UInt32
@@ -247,11 +275,11 @@ function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
247275
end
248276

249277
Sampler(::AbstractRNG, r::AbstractUnitRange{T},
250-
::Repetition) where {T<:Union{Bool,BitInteger}} = SamplerRangeInt(r)
278+
::Repetition) where {T<:BitInteger} = SamplerRangeInt(r)
251279

252-
rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:Union{Bool,BitInteger}} =
253-
(unsigned(sp.a) + rem_knuth(rand(rng, LessThan(sp.u, uniform(UInt32))), sp.k)) % T
254280

281+
rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:BitInteger} =
282+
(unsigned(sp.a) + rem_knuth(rand(rng, LessThan(sp.u, UInt52Raw(UInt32))), sp.k)) % T
255283

256284
# this function uses 52 bit entropy for small ranges of length <= 2^52
257285
function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt64}) where T<:BitInteger

0 commit comments

Comments
 (0)