@@ -60,24 +60,33 @@ srand(rng::RandomDevice) = rng
60
60
61
61
# # MersenneTwister
62
62
63
- const MTCacheLength = dsfmt_get_min_array_size ()
63
+ const MT_CACHE_F = dsfmt_get_min_array_size ()
64
+ const MT_CACHE_I = 501 << 4
64
65
65
66
mutable struct MersenneTwister <: AbstractRNG
66
67
seed:: Vector{UInt32}
67
68
state:: DSFMT_state
68
69
vals:: Vector{Float64}
69
- idx:: Int
70
-
71
- function MersenneTwister (seed, state, vals, idx)
72
- length (vals) == MTCacheLength && 0 <= idx <= MTCacheLength ||
73
- throw (DomainError ((length (vals), idx),
74
- " `length(vals)` and `idx` must be consistent with $MTCacheLength " ))
75
- new (seed, state, vals, idx)
70
+ ints:: Vector{UInt128}
71
+ idxF:: Int
72
+ idxI:: Int
73
+
74
+ function MersenneTwister (seed, state, vals, ints, idxF, idxI)
75
+ length (vals) == MT_CACHE_F && 0 <= idxF <= MT_CACHE_F ||
76
+ throw (DomainError ((length (vals), idxF),
77
+ " `length(vals)` and `idxF` must be consistent with $MT_CACHE_F " ))
78
+ length (ints) == MT_CACHE_I >> 4 && 0 <= idxI <= MT_CACHE_I ||
79
+ throw (DomainError ((length (ints), idxI),
80
+ " `length(ints)` and `idxI` must be consistent with $MT_CACHE_I " ))
81
+ new (seed, state, vals, ints, idxF, idxI)
76
82
end
77
83
end
78
84
79
85
MersenneTwister (seed:: Vector{UInt32} , state:: DSFMT_state ) =
80
- MersenneTwister (seed, state, zeros (Float64, MTCacheLength), MTCacheLength)
86
+ MersenneTwister (seed, state,
87
+ Vector {Float64} (uninitialized, MT_CACHE_F),
88
+ Vector {UInt128} (uninitialized, MT_CACHE_I >> 4 ),
89
+ MT_CACHE_F, 0 )
81
90
82
91
"""
83
92
MersenneTwister(seed)
@@ -120,27 +129,36 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
120
129
copyto! (resize! (dst. seed, length (src. seed)), src. seed)
121
130
copy! (dst. state, src. state)
122
131
copyto! (dst. vals, src. vals)
123
- dst. idx = src. idx
132
+ copyto! (dst. ints, src. ints)
133
+ dst. idxF = src. idxF
134
+ dst. idxI = src. idxI
124
135
dst
125
136
end
126
137
127
138
copy (src:: MersenneTwister ) =
128
- MersenneTwister (copy (src. seed), copy (src. state), copy (src. vals), src. idx)
139
+ MersenneTwister (copy (src. seed), copy (src. state), copy (src. vals), copy (src. ints),
140
+ src. idxF, src. idxI)
141
+
129
142
130
143
== (r1:: MersenneTwister , r2:: MersenneTwister ) =
131
- r1. seed == r2. seed && r1. state == r2. state && isequal (r1. vals, r2. vals) &&
132
- r1. idx == r2. idx
144
+ r1. seed == r2. seed && r1. state == r2. state &&
145
+ isequal (r1. vals, r2. vals) &&
146
+ isequal (r1. ints, r2. ints) &&
147
+ r1. idxF == r2. idxF && r1. idxI == r2. idxI
133
148
134
- hash (r:: MersenneTwister , h:: UInt ) = foldr (hash, h, (r. seed, r. state, r. vals, r. idx))
149
+ hash (r:: MersenneTwister , h:: UInt ) =
150
+ foldr (hash, h, (r. seed, r. state, r. vals, r. ints, r. idxF, r. idxI))
135
151
136
152
137
153
# ## low level API
138
154
139
- mt_avail (r:: MersenneTwister ) = MTCacheLength - r. idx
140
- mt_empty (r:: MersenneTwister ) = r. idx == MTCacheLength
141
- mt_setfull! (r:: MersenneTwister ) = r. idx = 0
142
- mt_setempty! (r:: MersenneTwister ) = r. idx = MTCacheLength
143
- mt_pop! (r:: MersenneTwister ) = @inbounds return r. vals[r. idx+= 1 ]
155
+ # ### floats
156
+
157
+ mt_avail (r:: MersenneTwister ) = MT_CACHE_F - r. idxF
158
+ mt_empty (r:: MersenneTwister ) = r. idxF == MT_CACHE_F
159
+ mt_setfull! (r:: MersenneTwister ) = r. idxF = 0
160
+ mt_setempty! (r:: MersenneTwister ) = r. idxF = MT_CACHE_F
161
+ mt_pop! (r:: MersenneTwister ) = @inbounds return r. vals[r. idxF+= 1 ]
144
162
145
163
function gen_rand (r:: MersenneTwister )
146
164
@gc_preserve r dsfmt_fill_array_close1_open2! (r. state, pointer (r. vals), length (r. vals))
149
167
150
168
reserve_1 (r:: MersenneTwister ) = (mt_empty (r) && gen_rand (r); nothing )
151
169
# `reserve` allows one to call `rand_inbounds` n times
152
- # precondition: n <= MTCacheLength
170
+ # precondition: n <= MT_CACHE_F
153
171
reserve (r:: MersenneTwister , n:: Int ) = (mt_avail (r) < n && gen_rand (r); nothing )
154
172
173
+ # ### ints
174
+
175
+ logsizeof (:: Type{<:Union{Bool,Int8,UInt8}} ) = 0
176
+ logsizeof (:: Type{<:Union{Int16,UInt16}} ) = 1
177
+ logsizeof (:: Type{<:Union{Int32,UInt32}} ) = 2
178
+ logsizeof (:: Type{<:Union{Int64,UInt64}} ) = 3
179
+ logsizeof (:: Type{<:Union{Int128,UInt128}} ) = 4
180
+
181
+ idxmask (:: Type{<:Union{Bool,Int8,UInt8}} ) = 15
182
+ idxmask (:: Type{<:Union{Int16,UInt16}} ) = 7
183
+ idxmask (:: Type{<:Union{Int32,UInt32}} ) = 3
184
+ idxmask (:: Type{<:Union{Int64,UInt64}} ) = 1
185
+ idxmask (:: Type{<:Union{Int128,UInt128}} ) = 0
186
+
187
+
188
+ mt_avail (r:: MersenneTwister , :: Type{T} ) where {T<: BitInteger } =
189
+ r. idxI >> logsizeof (T)
190
+
191
+ function mt_setfull! (r:: MersenneTwister , :: Type{<:BitInteger} )
192
+ rand! (r, r. ints)
193
+ r. idxI = MT_CACHE_I
194
+ end
195
+
196
+ mt_setempty! (r:: MersenneTwister , :: Type{<:BitInteger} ) = r. idxI = 0
197
+
198
+ function reserve1 (r:: MersenneTwister , :: Type{T} ) where T<: BitInteger
199
+ r. idxI < sizeof (T) && mt_setfull! (r, T)
200
+ nothing
201
+ end
202
+
203
+ function mt_pop! (r:: MersenneTwister , :: Type{T} ) where T<: BitInteger
204
+ reserve1 (r, T)
205
+ r. idxI -= sizeof (T)
206
+ i = r. idxI
207
+ @inbounds x128 = r. ints[1 + i >> 4 ]
208
+ i128 = (i >> logsizeof (T)) & idxmask (T) # 0-based "indice" in x128
209
+ (x128 >> (i128 * (sizeof (T) << 3 ))) % T
210
+ end
211
+ #=
212
+ function mt_pop!(r::MersenneTwister, ::Type{T}) where {T<:Union{Int128,UInt128}}
213
+ reserve1(r, T)
214
+ @inbounds res = r.ints[r.idxI >> 4]
215
+ r.idxI -= 16
216
+ res
217
+ end
218
+ =#
155
219
156
220
# ## seeding
157
221
@@ -193,6 +257,9 @@ function srand(r::MersenneTwister, seed::Vector{UInt32})
193
257
copyto! (resize! (r. seed, length (seed)), seed)
194
258
dsfmt_init_by_array (r. state, r. seed)
195
259
mt_setempty! (r)
260
+ fill! (r. vals, 0.0 ) # not strictly necessary, but why not, makes comparing two MT easier
261
+ mt_setempty! (r, UInt128)
262
+ fill! (r. ints, 0 )
196
263
return r
197
264
end
198
265
@@ -243,24 +310,8 @@ rand(r::MersenneTwister, sp::SamplerTrivial{Close1Open2_64}) =
243
310
244
311
# ### integers
245
312
246
- rand (r:: MersenneTwister ,
247
- T:: SamplerUnion (Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
248
- rand (r, UInt52Raw ()) % T[]
249
-
250
- function rand (r:: MersenneTwister , :: SamplerType{UInt64} )
251
- reserve (r, 2 )
252
- rand_inbounds (r, UInt52Raw ()) << 32 ⊻ rand_inbounds (r, UInt52Raw ())
253
- end
254
-
255
- function rand (r:: MersenneTwister , :: SamplerType{UInt128} )
256
- reserve (r, 3 )
257
- xor (rand_inbounds (r, UInt52Raw (UInt128)) << 96 ,
258
- rand_inbounds (r, UInt52Raw (UInt128)) << 48 ,
259
- rand_inbounds (r, UInt52Raw (UInt128)))
260
- end
261
-
262
- rand (r:: MersenneTwister , :: SamplerType{Int64} ) = rand (r, UInt64) % Int64
263
- rand (r:: MersenneTwister , :: SamplerType{Int128} ) = rand (r, UInt128) % Int128
313
+ rand (r:: MersenneTwister , T:: SamplerUnion (BitInteger)) = mt_pop! (r, T[])
314
+ rand (r:: MersenneTwister , :: SamplerType{Bool} ) = rand (r, UInt8) % Bool
264
315
265
316
# ### arrays of floats
266
317
@@ -315,13 +366,13 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter
315
366
mt_avail (r) == 0 && gen_rand (r)
316
367
# from now on, at most one call to gen_rand(r) will be necessary
317
368
m = min (n, mt_avail (r))
318
- @gc_preserve r unsafe_copyto! (A. ptr, pointer (r. vals, r. idx + 1 ), m)
369
+ @gc_preserve r unsafe_copyto! (A. ptr, pointer (r. vals, r. idxF + 1 ), m)
319
370
if m == n
320
- r. idx += m
371
+ r. idxF += m
321
372
else # m < n
322
373
gen_rand (r)
323
374
@gc_preserve r unsafe_copyto! (A. ptr+ m* sizeof (Float64), pointer (r. vals), n- m)
324
- r. idx = n- m
375
+ r. idxF = n- m
325
376
end
326
377
if I isa CloseOpen
327
378
for i= 1 : n
470
521
471
522
# ### from a range
472
523
473
- for T in (Bool, BitInteger_types... ) # eval because of ambiguity otherwise
524
+ for T in BitInteger_types # eval because of ambiguity otherwise
474
525
@eval Sampler (rng:: MersenneTwister , r:: UnitRange{$T} , :: Val{1} ) =
475
526
SamplerRangeFast (r)
476
527
end
0 commit comments