2
2
3
3
# # RandomDevice
4
4
5
- const BoolBitIntegerType = Union{Type{Bool},Base. BitIntegerType}
6
- const BoolBitIntegerArray = Union{Array{Bool},Base. BitIntegerArray}
5
+
6
+ StateTypes (U:: Union ) = Union{map (T-> StateType{T}, Base. uniontypes (U))... }
7
+ const StateBoolBitInteger = StateTypes (Union{Bool, Base. BitInteger})
7
8
8
9
if Sys. iswindows ()
9
10
struct RandomDevice <: AbstractRNG
@@ -12,15 +13,9 @@ if Sys.iswindows()
12
13
RandomDevice () = new (Vector {UInt128} (1 ))
13
14
end
14
15
15
- function rand (rd:: RandomDevice , T :: BoolBitIntegerType )
16
+ function rand (rd:: RandomDevice , st :: StateBoolBitInteger )
16
17
rand! (rd, rd. buffer)
17
- @inbounds return rd. buffer[1 ] % T
18
- end
19
-
20
- function rand! (rd:: RandomDevice , A:: BoolBitIntegerArray )
21
- ccall ((:SystemFunction036 , :Advapi32 ), stdcall, UInt8, (Ptr{Void}, UInt32),
22
- A, sizeof (A))
23
- A
18
+ @inbounds return rd. buffer[1 ] % st[]
24
19
end
25
20
else # !windows
26
21
struct RandomDevice <: AbstractRNG
@@ -31,10 +26,22 @@ else # !windows
31
26
new (open (unlimited ? " /dev/urandom" : " /dev/random" ), unlimited)
32
27
end
33
28
34
- rand (rd:: RandomDevice , T:: BoolBitIntegerType ) = read ( rd. file, T)
35
- rand! (rd:: RandomDevice , A:: BoolBitIntegerArray ) = read! (rd. file, A)
29
+ rand (rd:: RandomDevice , st:: StateBoolBitInteger ) = read ( rd. file, st[])
36
30
end # os-test
37
31
32
+ # NOTE: this can't be put in within the if-else block above
33
+ for T in (Bool, Base. BitInteger_types... )
34
+ if Sys. iswindows ()
35
+ @eval function rand! (rd:: RandomDevice , A:: Array{$T} , :: StateType{$T} )
36
+ ccall ((:SystemFunction036 , :Advapi32 ), stdcall, UInt8, (Ptr{Void}, UInt32),
37
+ A, sizeof (A))
38
+ A
39
+ end
40
+ else
41
+ @eval rand! (rd:: RandomDevice , A:: Array{$T} , :: StateType{$T} ) = read! (rd. file, A)
42
+ end
43
+ end
44
+
38
45
"""
39
46
RandomDevice()
40
47
@@ -49,7 +56,7 @@ srand(rng::RandomDevice) = rng
49
56
50
57
# ## generation of floats
51
58
52
- rand (r:: RandomDevice , I :: FloatInterval ) = rand_generic (r, I )
59
+ rand (r:: RandomDevice , st :: StateTrivial{<: FloatInterval} ) = rand_generic (r, st[] )
53
60
54
61
55
62
# # MersenneTwister
@@ -229,30 +236,30 @@ rand_ui23_raw(r::MersenneTwister) = rand_ui52_raw(r)
229
236
230
237
# ### floats
231
238
232
- rand (r:: MersenneTwister , I :: FloatInterval_64 ) = (reserve_1 (r); rand_inbounds (r, I ))
239
+ rand (r:: MersenneTwister , st :: StateTrivial{<: FloatInterval_64} ) = (reserve_1 (r); rand_inbounds (r, st[] ))
233
240
234
- rand (r:: MersenneTwister , I :: FloatInterval ) = rand_generic (r, I )
241
+ rand (r:: MersenneTwister , st :: StateTrivial{<: FloatInterval} ) = rand_generic (r, st[] )
235
242
236
243
# ### integers
237
244
238
245
rand (r:: MersenneTwister ,
239
- :: Type{T} ) where {T <: Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32} } =
240
- rand_ui52_raw (r) % T
246
+ T :: StateTypes ( Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
247
+ rand_ui52_raw (r) % T[]
241
248
242
- function rand (r:: MersenneTwister , :: Type {UInt64} )
249
+ function rand (r:: MersenneTwister , :: StateType {UInt64} )
243
250
reserve (r, 2 )
244
251
rand_ui52_raw_inbounds (r) << 32 ⊻ rand_ui52_raw_inbounds (r)
245
252
end
246
253
247
- function rand (r:: MersenneTwister , :: Type {UInt128} )
254
+ function rand (r:: MersenneTwister , :: StateType {UInt128} )
248
255
reserve (r, 3 )
249
256
xor (rand_ui52_raw_inbounds (r) % UInt128 << 96 ,
250
257
rand_ui52_raw_inbounds (r) % UInt128 << 48 ,
251
258
rand_ui52_raw_inbounds (r))
252
259
end
253
260
254
- rand (r:: MersenneTwister , :: Type {Int64} ) = reinterpret (Int64, rand (r, UInt64))
255
- rand (r:: MersenneTwister , :: Type {Int128} ) = reinterpret (Int128, rand (r, UInt128))
261
+ rand (r:: MersenneTwister , :: StateType {Int64} ) = reinterpret (Int64, rand (r, UInt64))
262
+ rand (r:: MersenneTwister , :: StateType {Int128} ) = reinterpret (Int128, rand (r, UInt128))
256
263
257
264
# ### arrays of floats
258
265
@@ -278,16 +285,17 @@ function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float6
278
285
A
279
286
end
280
287
281
- rand! (r:: MersenneTwister , A:: AbstractArray{Float64} ) = rand_AbstractArray_Float64! (r, A)
288
+ rand! (r:: MersenneTwister , A:: AbstractArray{Float64} , I:: StateTrivial{<:FloatInterval_64} ) =
289
+ rand_AbstractArray_Float64! (r, A, length (A), I[])
282
290
283
291
fill_array! (s:: DSFMT_state , A:: Ptr{Float64} , n:: Int , :: CloseOpen_64 ) =
284
292
dsfmt_fill_array_close_open! (s, A, n)
285
293
286
294
fill_array! (s:: DSFMT_state , A:: Ptr{Float64} , n:: Int , :: Close1Open2_64 ) =
287
295
dsfmt_fill_array_close1_open2! (s, A, n)
288
296
289
- function rand ! (r:: MersenneTwister , A:: Array{Float64} , n:: Int = length (A) ,
290
- I:: FloatInterval_64 = CloseOpen () )
297
+ function _rand ! (r:: MersenneTwister , A:: Array{Float64} , n:: Int ,
298
+ I:: FloatInterval_64 )
291
299
# depending on the alignment of A, the data written by fill_array! may have
292
300
# to be left-shifted by up to 15 bytes (cf. unsafe_copy! below) for
293
301
# reproducibility purposes;
@@ -317,65 +325,63 @@ function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A),
317
325
A
318
326
end
319
327
328
+ rand! (r:: MersenneTwister , A:: Array{Float64} , st:: StateTrivial{<:FloatInterval_64} ) =
329
+ _rand! (r, A, length (A), st[])
330
+
320
331
mask128 (u:: UInt128 , :: Type{Float16} ) =
321
332
(u & 0x03ff03ff03ff03ff03ff03ff03ff03ff ) | 0x3c003c003c003c003c003c003c003c00
322
333
323
334
mask128 (u:: UInt128 , :: Type{Float32} ) =
324
335
(u & 0x007fffff007fffff007fffff007fffff ) | 0x3f8000003f8000003f8000003f800000
325
336
326
- function rand! (r:: MersenneTwister , A:: Union{Array{Float16},Array{Float32}} ,
327
- :: Close1Open2_64 )
328
- T = eltype (A)
329
- n = length (A)
330
- n128 = n * sizeof (T) ÷ 16
331
- Base. @gc_preserve A rand! (r, unsafe_wrap (Array, convert (Ptr{Float64}, pointer (A)), 2 * n128),
332
- 2 * n128, Close1Open2 ())
333
- # FIXME : This code is completely invalid!!!
334
- A128 = unsafe_wrap (Array, convert (Ptr{UInt128}, pointer (A)), n128)
335
- @inbounds for i in 1 : n128
336
- u = A128[i]
337
- u ⊻= u << 26
338
- # at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
339
- # the bit xor, are:
340
- # [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
341
- # the bits needing to be random are
342
- # [1:10, 17:26, 33:42, 49:58] (for Float16)
343
- # [1:23, 33:55] (for Float32)
344
- # this is obviously satisfied on the 32 low bits side, and on the high side,
345
- # the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
346
- # (which are discarded on the low side)
347
- # this is similar for the 64 high bits of u
348
- A128[i] = mask128 (u, T)
349
- end
350
- for i in 16 * n128÷ sizeof (T)+ 1 : n
351
- @inbounds A[i] = rand (r, T) + oneunit (T)
337
+ for T in (Float16, Float32)
338
+ @eval function rand! (r:: MersenneTwister , A:: Array{$T} , :: StateTrivial{Close1Open2{$T}} )
339
+ n = length (A)
340
+ n128 = n * sizeof ($ T) ÷ 16
341
+ Base. @gc_preserve A _rand! (r, unsafe_wrap (Array, convert (Ptr{Float64}, pointer (A)), 2 * n128),
342
+ 2 * n128, Close1Open2 ())
343
+ # FIXME : This code is completely invalid!!!
344
+ A128 = unsafe_wrap (Array, convert (Ptr{UInt128}, pointer (A)), n128)
345
+ @inbounds for i in 1 : n128
346
+ u = A128[i]
347
+ u ⊻= u << 26
348
+ # at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
349
+ # the bit xor, are:
350
+ # [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
351
+ # the bits needing to be random are
352
+ # [1:10, 17:26, 33:42, 49:58] (for Float16)
353
+ # [1:23, 33:55] (for Float32)
354
+ # this is obviously satisfied on the 32 low bits side, and on the high side,
355
+ # the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
356
+ # (which are discarded on the low side)
357
+ # this is similar for the 64 high bits of u
358
+ A128[i] = mask128 (u, $ T)
359
+ end
360
+ for i in 16 * n128÷ sizeof ($ T)+ 1 : n
361
+ @inbounds A[i] = rand (r, $ T) + oneunit ($ T)
362
+ end
363
+ A
352
364
end
353
- A
354
- end
355
365
356
- function rand! (r:: MersenneTwister , A:: Union{Array{Float16},Array{Float32}} , :: CloseOpen_64 )
357
- rand! (r, A, Close1Open2 ())
358
- I32 = one (Float32)
359
- for i in eachindex (A)
360
- @inbounds A[i] = Float32 (A[i])- I32 # faster than "A[i] -= one(T)" for T==Float16
366
+ @eval function rand! (r:: MersenneTwister , A:: Array{$T} , :: StateTrivial{CloseOpen{$T}} )
367
+ rand! (r, A, Close1Open2 ($ T))
368
+ I32 = one (Float32)
369
+ for i in eachindex (A)
370
+ @inbounds A[i] = Float32 (A[i])- I32 # faster than "A[i] -= one(T)" for T==Float16
371
+ end
372
+ A
361
373
end
362
- A
363
374
end
364
375
365
- rand! (r:: MersenneTwister , A:: Union{Array{Float16},Array{Float32}} ) =
366
- rand! (r, A, CloseOpen ())
367
-
368
376
# ### arrays of integers
369
377
370
- function rand! (r:: MersenneTwister , A:: Array{UInt128} , n:: Int = length (A))
371
- if n > length (A)
372
- throw (BoundsError (A,n))
373
- end
378
+ function rand! (r:: MersenneTwister , A:: Array{UInt128} , :: StateType{UInt128} )
379
+ n:: Int = length (A)
374
380
# FIXME : This code is completely invalid!!!
375
381
Af = unsafe_wrap (Array, convert (Ptr{Float64}, pointer (A)), 2 n)
376
382
i = n
377
383
while true
378
- rand ! (r, Af, 2 i, Close1Open2 ())
384
+ _rand ! (r, Af, 2 i, Close1Open2 ())
379
385
n < 5 && break
380
386
i = 0
381
387
@inbounds while n- i >= 5
@@ -396,17 +402,18 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A))
396
402
A
397
403
end
398
404
399
- # A::Array{UInt128} will match the specialized method above
400
- function rand! (r:: MersenneTwister , A:: Base.BitIntegerArray )
401
- n = length (A)
402
- T = eltype (A)
403
- n128 = n * sizeof (T) ÷ 16
404
- # FIXME : This code is completely invalid!!!
405
- rand! (r, unsafe_wrap (Array, convert (Ptr{UInt128}, pointer (A)), n128))
406
- for i = 16 * n128÷ sizeof (T)+ 1 : n
407
- @inbounds A[i] = rand (r, T)
405
+ for T in Base. BitInteger_types
406
+ T === UInt128 && continue
407
+ @eval function rand! (r:: MersenneTwister , A:: Array{$T} , :: StateType{$T} )
408
+ n = length (A)
409
+ n128 = n * sizeof ($ T) ÷ 16
410
+ # FIXME : This code is completely invalid!!!
411
+ rand! (r, unsafe_wrap (Array, convert (Ptr{UInt128}, pointer (A)), n128))
412
+ for i = 16 * n128÷ sizeof ($ T)+ 1 : n
413
+ @inbounds A[i] = rand (r, $ T)
414
+ end
415
+ A
408
416
end
409
- A
410
417
end
411
418
412
419
# ### from a range
@@ -418,7 +425,9 @@ function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer
418
425
end
419
426
end
420
427
421
- function rand (rng:: MersenneTwister , r:: UnitRange{T} ) where T<: Union{Base.BitInteger64,Bool}
428
+ function rand (rng:: MersenneTwister ,
429
+ st:: StateTrivial{UnitRange{T}} ) where T<: Union{Base.BitInteger64,Bool}
430
+ r = st[]
422
431
isempty (r) && throw (ArgumentError (" range must be non-empty" ))
423
432
m = last (r) % UInt64 - first (r) % UInt64
424
433
bw = (64 - leading_zeros (m)) % UInt # bit-width
@@ -428,7 +437,9 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInte
428
437
(x + first (r) % UInt64) % T
429
438
end
430
439
431
- function rand (rng:: MersenneTwister , r:: UnitRange{T} ) where T<: Union{Int128,UInt128}
440
+ function rand (rng:: MersenneTwister ,
441
+ st:: StateTrivial{UnitRange{T}} ) where T<: Union{Int128,UInt128}
442
+ r = st[]
432
443
isempty (r) && throw (ArgumentError (" range must be non-empty" ))
433
444
m = (last (r)- first (r)) % UInt128
434
445
bw = (128 - leading_zeros (m)) % UInt # bit-width
@@ -439,6 +450,11 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt1
439
450
x % T + first (r)
440
451
end
441
452
453
+ for T in (Bool, Base. BitInteger_types... ) # eval because of ambiguity otherwise
454
+ @eval State (rng:: MersenneTwister , r:: UnitRange{$T} , :: Val{1} ) =
455
+ StateTrivial (r)
456
+ end
457
+
442
458
443
459
# ## randjump
444
460
0 commit comments