2
2
3
3
# # RandomDevice
4
4
5
- const BoolBitIntegerType = Union{Type{Bool},Base. BitIntegerType}
6
- const BoolBitIntegerArray = Union{Array{Bool},Base. BitIntegerArray}
5
+ # SamplerUnion(Union{X,Y,...}) == Union{SamplerType{X},SamplerType{Y},...}
6
+ SamplerUnion (U:: Union ) = Union{map (T-> SamplerType{T}, Base. uniontypes (U))... }
7
+ const SamplerBoolBitInteger = SamplerUnion (Union{Bool, Base. BitInteger})
7
8
8
9
if Sys. iswindows ()
9
10
struct RandomDevice <: AbstractRNG
@@ -12,15 +13,9 @@ if Sys.iswindows()
12
13
RandomDevice () = new (Vector {UInt128} (uninitialized, 1 ))
13
14
end
14
15
15
- function rand (rd:: RandomDevice , T :: BoolBitIntegerType )
16
+ function rand (rd:: RandomDevice , sp :: SamplerBoolBitInteger )
16
17
rand! (rd, rd. buffer)
17
- @inbounds return rd. buffer[1 ] % T
18
- end
19
-
20
- function rand! (rd:: RandomDevice , A:: BoolBitIntegerArray )
21
- ccall ((:SystemFunction036 , :Advapi32 ), stdcall, UInt8, (Ptr{Void}, UInt32),
22
- A, sizeof (A))
23
- A
18
+ @inbounds return rd. buffer[1 ] % sp[]
24
19
end
25
20
else # !windows
26
21
struct RandomDevice <: AbstractRNG
@@ -31,10 +26,22 @@ else # !windows
31
26
new (open (unlimited ? " /dev/urandom" : " /dev/random" ), unlimited)
32
27
end
33
28
34
- rand (rd:: RandomDevice , T:: BoolBitIntegerType ) = read ( rd. file, T)
35
- rand! (rd:: RandomDevice , A:: BoolBitIntegerArray ) = read! (rd. file, A)
29
+ rand (rd:: RandomDevice , sp:: SamplerBoolBitInteger ) = read ( rd. file, sp[])
36
30
end # os-test
37
31
32
+ # NOTE: this can't be put within the if-else block above
33
+ for T in (Bool, Base. BitInteger_types... )
34
+ if Sys. iswindows ()
35
+ @eval function rand! (rd:: RandomDevice , A:: Array{$T} , :: SamplerType{$T} )
36
+ ccall ((:SystemFunction036 , :Advapi32 ), stdcall, UInt8, (Ptr{Void}, UInt32),
37
+ A, sizeof (A))
38
+ A
39
+ end
40
+ else
41
+ @eval rand! (rd:: RandomDevice , A:: Array{$T} , :: SamplerType{$T} ) = read! (rd. file, A)
42
+ end
43
+ end
44
+
38
45
"""
39
46
RandomDevice()
40
47
@@ -49,7 +56,7 @@ srand(rng::RandomDevice) = rng
49
56
50
57
# ## generation of floats
51
58
52
- rand (r:: RandomDevice , I :: FloatInterval ) = rand_generic (r, I )
59
+ rand (r:: RandomDevice , sp :: SamplerTrivial{<: FloatInterval} ) = rand_generic (r, sp[] )
53
60
54
61
55
62
# # MersenneTwister
@@ -229,30 +236,31 @@ rand_ui23_raw(r::MersenneTwister) = rand_ui52_raw(r)
229
236
230
237
# ### floats
231
238
232
- rand (r:: MersenneTwister , I:: FloatInterval_64 ) = (reserve_1 (r); rand_inbounds (r, I))
239
+ rand (r:: MersenneTwister , sp:: SamplerTrivial{<:FloatInterval_64} ) =
240
+ (reserve_1 (r); rand_inbounds (r, sp[]))
233
241
234
- rand (r:: MersenneTwister , I :: FloatInterval ) = rand_generic (r, I )
242
+ rand (r:: MersenneTwister , sp :: SamplerTrivial{<: FloatInterval} ) = rand_generic (r, sp[] )
235
243
236
244
# ### integers
237
245
238
- rand (r:: MersenneTwister , T :: Union {Type{Bool}, Type{Int8}, Type{UInt8}, Type{Int16}, Type{UInt16},
239
- Type{ Int32}, Type{ UInt32}} ) =
240
- rand_ui52_raw (r) % T
246
+ rand (r:: MersenneTwister ,
247
+ T :: SamplerUnion (Union{Bool,Int8,UInt8,Int16,UInt16, Int32, UInt32}) ) =
248
+ rand_ui52_raw (r) % T[]
241
249
242
- function rand (r:: MersenneTwister , :: Type {UInt64} )
250
+ function rand (r:: MersenneTwister , :: SamplerType {UInt64} )
243
251
reserve (r, 2 )
244
252
rand_ui52_raw_inbounds (r) << 32 ⊻ rand_ui52_raw_inbounds (r)
245
253
end
246
254
247
- function rand (r:: MersenneTwister , :: Type {UInt128} )
255
+ function rand (r:: MersenneTwister , :: SamplerType {UInt128} )
248
256
reserve (r, 3 )
249
257
xor (rand_ui52_raw_inbounds (r) % UInt128 << 96 ,
250
258
rand_ui52_raw_inbounds (r) % UInt128 << 48 ,
251
259
rand_ui52_raw_inbounds (r))
252
260
end
253
261
254
- rand (r:: MersenneTwister , :: Type {Int64} ) = reinterpret (Int64, rand (r, UInt64))
255
- rand (r:: MersenneTwister , :: Type {Int128} ) = reinterpret (Int128, rand (r, UInt128))
262
+ rand (r:: MersenneTwister , :: SamplerType {Int64} ) = reinterpret (Int64, rand (r, UInt64))
263
+ rand (r:: MersenneTwister , :: SamplerType {Int128} ) = reinterpret (Int128, rand (r, UInt128))
256
264
257
265
# ### arrays of floats
258
266
@@ -278,16 +286,17 @@ function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float6
278
286
A
279
287
end
280
288
281
- rand! (r:: MersenneTwister , A:: AbstractArray{Float64} ) = rand_AbstractArray_Float64! (r, A)
289
+ rand! (r:: MersenneTwister , A:: AbstractArray{Float64} , I:: SamplerTrivial{<:FloatInterval_64} ) =
290
+ rand_AbstractArray_Float64! (r, A, length (A), I[])
282
291
283
292
fill_array! (s:: DSFMT_state , A:: Ptr{Float64} , n:: Int , :: CloseOpen_64 ) =
284
293
dsfmt_fill_array_close_open! (s, A, n)
285
294
286
295
fill_array! (s:: DSFMT_state , A:: Ptr{Float64} , n:: Int , :: Close1Open2_64 ) =
287
296
dsfmt_fill_array_close1_open2! (s, A, n)
288
297
289
- function rand ! (r:: MersenneTwister , A:: Array{Float64} , n:: Int = length (A) ,
290
- I:: FloatInterval_64 = CloseOpen () )
298
+ function _rand ! (r:: MersenneTwister , A:: Array{Float64} , n:: Int ,
299
+ I:: FloatInterval_64 )
291
300
# depending on the alignment of A, the data written by fill_array! may have
292
301
# to be left-shifted by up to 15 bytes (cf. unsafe_copy! below) for
293
302
# reproducibility purposes;
@@ -317,65 +326,63 @@ function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A),
317
326
A
318
327
end
319
328
329
+ rand! (r:: MersenneTwister , A:: Array{Float64} , sp:: SamplerTrivial{<:FloatInterval_64} ) =
330
+ _rand! (r, A, length (A), sp[])
331
+
320
332
mask128 (u:: UInt128 , :: Type{Float16} ) =
321
333
(u & 0x03ff03ff03ff03ff03ff03ff03ff03ff ) | 0x3c003c003c003c003c003c003c003c00
322
334
323
335
mask128 (u:: UInt128 , :: Type{Float32} ) =
324
336
(u & 0x007fffff007fffff007fffff007fffff ) | 0x3f8000003f8000003f8000003f800000
325
337
326
- function rand! (r:: MersenneTwister , A:: Union{Array{Float16},Array{Float32}} ,
327
- :: Close1Open2_64 )
328
- T = eltype (A)
329
- n = length (A)
330
- n128 = n * sizeof (T) ÷ 16
331
- Base. @gc_preserve A rand! (r, unsafe_wrap (Array, convert (Ptr{Float64}, pointer (A)), 2 * n128),
332
- 2 * n128, Close1Open2 ())
333
- # FIXME : This code is completely invalid!!!
334
- A128 = unsafe_wrap (Array, convert (Ptr{UInt128}, pointer (A)), n128)
335
- @inbounds for i in 1 : n128
336
- u = A128[i]
337
- u ⊻= u << 26
338
- # at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
339
- # the bit xor, are:
340
- # [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
341
- # the bits needing to be random are
342
- # [1:10, 17:26, 33:42, 49:58] (for Float16)
343
- # [1:23, 33:55] (for Float32)
344
- # this is obviously satisfied on the 32 low bits side, and on the high side,
345
- # the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
346
- # (which are discarded on the low side)
347
- # this is similar for the 64 high bits of u
348
- A128[i] = mask128 (u, T)
349
- end
350
- for i in 16 * n128÷ sizeof (T)+ 1 : n
351
- @inbounds A[i] = rand (r, T) + oneunit (T)
338
+ for T in (Float16, Float32)
339
+ @eval function rand! (r:: MersenneTwister , A:: Array{$T} , :: SamplerTrivial{Close1Open2{$T}} )
340
+ n = length (A)
341
+ n128 = n * sizeof ($ T) ÷ 16
342
+ Base. @gc_preserve A _rand! (r, unsafe_wrap (Array, convert (Ptr{Float64}, pointer (A)), 2 * n128),
343
+ 2 * n128, Close1Open2 ())
344
+ # FIXME : This code is completely invalid!!!
345
+ A128 = unsafe_wrap (Array, convert (Ptr{UInt128}, pointer (A)), n128)
346
+ @inbounds for i in 1 : n128
347
+ u = A128[i]
348
+ u ⊻= u << 26
349
+ # at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
350
+ # the bit xor, are:
351
+ # [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
352
+ # the bits needing to be random are
353
+ # [1:10, 17:26, 33:42, 49:58] (for Float16)
354
+ # [1:23, 33:55] (for Float32)
355
+ # this is obviously satisfied on the 32 low bits side, and on the high side,
356
+ # the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
357
+ # (which are discarded on the low side)
358
+ # this is similar for the 64 high bits of u
359
+ A128[i] = mask128 (u, $ T)
360
+ end
361
+ for i in 16 * n128÷ sizeof ($ T)+ 1 : n
362
+ @inbounds A[i] = rand (r, $ T) + oneunit ($ T)
363
+ end
364
+ A
352
365
end
353
- A
354
- end
355
366
356
- function rand! (r:: MersenneTwister , A:: Union{Array{Float16},Array{Float32}} , :: CloseOpen_64 )
357
- rand! (r, A, Close1Open2 ())
358
- I32 = one (Float32)
359
- for i in eachindex (A)
360
- @inbounds A[i] = Float32 (A[i])- I32 # faster than "A[i] -= one(T)" for T==Float16
367
+ @eval function rand! (r:: MersenneTwister , A:: Array{$T} , :: SamplerTrivial{CloseOpen{$T}} )
368
+ rand! (r, A, Close1Open2 ($ T))
369
+ I32 = one (Float32)
370
+ for i in eachindex (A)
371
+ @inbounds A[i] = Float32 (A[i])- I32 # faster than "A[i] -= one(T)" for T==Float16
372
+ end
373
+ A
361
374
end
362
- A
363
375
end
364
376
365
- rand! (r:: MersenneTwister , A:: Union{Array{Float16},Array{Float32}} ) =
366
- rand! (r, A, CloseOpen ())
367
-
368
377
# ### arrays of integers
369
378
370
- function rand! (r:: MersenneTwister , A:: Array{UInt128} , n:: Int = length (A))
371
- if n > length (A)
372
- throw (BoundsError (A,n))
373
- end
379
+ function rand! (r:: MersenneTwister , A:: Array{UInt128} , :: SamplerType{UInt128} )
380
+ n:: Int = length (A)
374
381
# FIXME : This code is completely invalid!!!
375
382
Af = unsafe_wrap (Array, convert (Ptr{Float64}, pointer (A)), 2 n)
376
383
i = n
377
384
while true
378
- rand ! (r, Af, 2 i, Close1Open2 ())
385
+ _rand ! (r, Af, 2 i, Close1Open2 ())
379
386
n < 5 && break
380
387
i = 0
381
388
@inbounds while n- i >= 5
@@ -396,17 +403,18 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A))
396
403
A
397
404
end
398
405
399
- # A::Array{UInt128} will match the specialized method above
400
- function rand! (r:: MersenneTwister , A:: Base.BitIntegerArray )
401
- n = length (A)
402
- T = eltype (A)
403
- n128 = n * sizeof (T) ÷ 16
404
- # FIXME : This code is completely invalid!!!
405
- rand! (r, unsafe_wrap (Array, convert (Ptr{UInt128}, pointer (A)), n128))
406
- for i = 16 * n128÷ sizeof (T)+ 1 : n
407
- @inbounds A[i] = rand (r, T)
406
+ for T in Base. BitInteger_types
407
+ T === UInt128 && continue
408
+ @eval function rand! (r:: MersenneTwister , A:: Array{$T} , :: SamplerType{$T} )
409
+ n = length (A)
410
+ n128 = n * sizeof ($ T) ÷ 16
411
+ # FIXME : This code is completely invalid!!!
412
+ rand! (r, unsafe_wrap (Array, convert (Ptr{UInt128}, pointer (A)), n128))
413
+ for i = 16 * n128÷ sizeof ($ T)+ 1 : n
414
+ @inbounds A[i] = rand (r, $ T)
415
+ end
416
+ A
408
417
end
409
- A
410
418
end
411
419
412
420
# ### from a range
@@ -418,7 +426,9 @@ function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer
418
426
end
419
427
end
420
428
421
- function rand (rng:: MersenneTwister , r:: UnitRange{T} ) where T<: Union{Base.BitInteger64,Bool}
429
+ function rand (rng:: MersenneTwister ,
430
+ sp:: SamplerTrivial{UnitRange{T}} ) where T<: Union{Base.BitInteger64,Bool}
431
+ r = sp[]
422
432
isempty (r) && throw (ArgumentError (" range must be non-empty" ))
423
433
m = last (r) % UInt64 - first (r) % UInt64
424
434
bw = (64 - leading_zeros (m)) % UInt # bit-width
@@ -428,7 +438,9 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInte
428
438
(x + first (r) % UInt64) % T
429
439
end
430
440
431
- function rand (rng:: MersenneTwister , r:: UnitRange{T} ) where T<: Union{Int128,UInt128}
441
+ function rand (rng:: MersenneTwister ,
442
+ sp:: SamplerTrivial{UnitRange{T}} ) where T<: Union{Int128,UInt128}
443
+ r = sp[]
432
444
isempty (r) && throw (ArgumentError (" range must be non-empty" ))
433
445
m = (last (r)- first (r)) % UInt128
434
446
bw = (128 - leading_zeros (m)) % UInt # bit-width
@@ -439,6 +451,11 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt1
439
451
x % T + first (r)
440
452
end
441
453
454
+ for T in (Bool, Base. BitInteger_types... ) # eval because of ambiguity otherwise
455
+ @eval Sampler (rng:: MersenneTwister , r:: UnitRange{$T} , :: Val{1} ) =
456
+ SamplerTrivial (r)
457
+ end
458
+
442
459
443
460
# ## randjump
444
461
0 commit comments