Skip to content

Commit 417e52e

Browse files
chethegaJeffBezanson
authored andcommitted
Introduce task-local and free-standing xoshiro RNG
1 parent 592db58 commit 417e52e

File tree

8 files changed

+657
-5
lines changed

8 files changed

+657
-5
lines changed

src/jltypes.c

+12-4
Original file line numberDiff line numberDiff line change
@@ -2462,7 +2462,7 @@ void jl_init_types(void) JL_GC_DISABLED
24622462
NULL,
24632463
jl_any_type,
24642464
jl_emptysvec,
2465-
jl_perm_symsvec(10,
2465+
jl_perm_symsvec(14,
24662466
"next",
24672467
"queue",
24682468
"storage",
@@ -2472,8 +2472,12 @@ void jl_init_types(void) JL_GC_DISABLED
24722472
"code",
24732473
"_state",
24742474
"sticky",
2475-
"_isexception"),
2476-
jl_svec(10,
2475+
"_isexception",
2476+
"rngState0",
2477+
"rngState1",
2478+
"rngState2",
2479+
"rngState3"),
2480+
jl_svec(14,
24772481
jl_any_type,
24782482
jl_any_type,
24792483
jl_any_type,
@@ -2483,7 +2487,11 @@ void jl_init_types(void) JL_GC_DISABLED
24832487
jl_any_type,
24842488
jl_uint8_type,
24852489
jl_bool_type,
2486-
jl_bool_type),
2490+
jl_bool_type,
2491+
jl_uint64_type,
2492+
jl_uint64_type,
2493+
jl_uint64_type,
2494+
jl_uint64_type),
24872495
0, 1, 6);
24882496
jl_value_t *listt = jl_new_struct(jl_uniontype_type, jl_task_type, jl_nothing_type);
24892497
jl_svecset(jl_task_type->types, 0, listt);

src/julia.h

+4
Original file line numberDiff line numberDiff line change
@@ -1809,6 +1809,10 @@ typedef struct _jl_task_t {
18091809
uint8_t _state;
18101810
uint8_t sticky; // record whether this Task can be migrated to a new thread
18111811
uint8_t _isexception; // set if `result` is an exception to throw or that we exited with
1812+
uint64_t rngState0; // really rngState[4], but more convenient to split
1813+
uint64_t rngState1;
1814+
uint64_t rngState2;
1815+
uint64_t rngState3;
18121816

18131817
// hidden state:
18141818
// id of owning thread - does not need to be defined until the task runs

src/task.c

+60
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "julia_internal.h"
3636
#include "threading.h"
3737
#include "julia_assert.h"
38+
#include "support/hashing.h"
3839

3940
#ifdef __cplusplus
4041
extern "C" {
@@ -666,6 +667,63 @@ JL_DLLEXPORT void jl_rethrow_other(jl_value_t *e JL_MAYBE_UNROOTED)
666667
throw_internal(NULL);
667668
}
668669

670+
/* This is xoshiro256++ 1.0, used for tasklocal random number generation in julia.
671+
This implementation is intended for embedders and internal use by the runtime, and is
672+
based on the reference implementation on http://prng.di.unimi.it
673+
674+
Credits go to Sebastiano Vigna for coming up with this PRNG.
675+
676+
There is a pure julia implementation in stdlib that tends to be faster when used from
677+
within julia, due to inlining and more agressive architecture-specific optimizations.
678+
*/
679+
JL_DLLEXPORT uint64_t jl_tasklocal_genrandom(jl_task_t *task) JL_NOTSAFEPOINT
680+
{
681+
uint64_t s0 = task->rngState0;
682+
uint64_t s1 = task->rngState1;
683+
uint64_t s2 = task->rngState2;
684+
uint64_t s3 = task->rngState3;
685+
686+
uint64_t t = s0 << 17;
687+
uint64_t tmp = s0 + s3;
688+
uint64_t res = ((tmp << 23) | (tmp >> 41)) + s0;
689+
s2 ^= s0;
690+
s3 ^= s1;
691+
s1 ^= s2;
692+
s0 ^= s3;
693+
s2 ^= t;
694+
s3 = (s3 << 45) | (s3 >> 19);
695+
696+
task->rngState0 = s0;
697+
task->rngState1 = s1;
698+
task->rngState2 = s2;
699+
task->rngState3 = s3;
700+
return res;
701+
}
702+
703+
void rng_split(jl_task_t *from, jl_task_t *to) JL_NOTSAFEPOINT
704+
{
705+
/* TODO: consider a less ad-hoc construction
706+
Ideally we could just use the output of the random stream to seed the initial
707+
state of the child. Out of an overabundance of caution we multiply with
708+
effectively random coefficients, to break possible self-interactions.
709+
710+
It is not the goal to mix bits -- we work under the assumption that the
711+
source is well-seeded, and its output looks effectively random.
712+
However, xoshiro has never been studied in the mode where we seed the
713+
initial state with the output of another xoshiro instance.
714+
715+
Constants have nothing up their sleeve:
716+
0x02011ce34bce797f == hash(UInt(1))|0x01
717+
0x5a94851fb48a6e05 == hash(UInt(2))|0x01
718+
0x3688cf5d48899fa7 == hash(UInt(3))|0x01
719+
0x867b4bb4c42e5661 == hash(UInt(4))|0x01
720+
*/
721+
to->rngState0 = 0x02011ce34bce797f * jl_tasklocal_genrandom(from);
722+
to->rngState1 = 0x5a94851fb48a6e05 * jl_tasklocal_genrandom(from);
723+
to->rngState2 = 0x3688cf5d48899fa7 * jl_tasklocal_genrandom(from);
724+
to->rngState3 = 0x867b4bb4c42e5661 * jl_tasklocal_genrandom(from);
725+
}
726+
669727
JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion_future, size_t ssize)
670728
{
671729
jl_ptls_t ptls = jl_get_ptls_states();
@@ -701,6 +759,8 @@ JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion
701759
t->_isexception = 0;
702760
// Inherit logger state from parent task
703761
t->logstate = ptls->current_task->logstate;
762+
// Fork task-local random state from parent
763+
rng_split(ptls->current_task, t);
704764
// there is no active exception handler available on this stack yet
705765
t->eh = NULL;
706766
t->sticky = 1;

stdlib/Random/src/RNGs.jl

+1
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ end
382382

383383
function __init__()
384384
resize!(empty!(THREAD_RNGs), Threads.nthreads()) # ensures that we didn't save a bad object
385+
seed!(TaskLocalRNG())
385386
end
386387

387388

stdlib/Random/src/Random.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ export rand!, randn!,
2727
shuffle, shuffle!,
2828
randperm, randperm!,
2929
randcycle, randcycle!,
30-
AbstractRNG, MersenneTwister, RandomDevice
30+
AbstractRNG, MersenneTwister, RandomDevice, TaskLocalRNG, Xoshiro
3131

3232
## general definitions
3333

@@ -292,10 +292,12 @@ rand(r::AbstractRNG, ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(r
292292
rand( ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(X, Dims((d, dims...)))
293293

294294

295+
include("Xoshiro.jl")
295296
include("RNGs.jl")
296297
include("generation.jl")
297298
include("normal.jl")
298299
include("misc.jl")
300+
include("XoshiroSimd.jl")
299301

300302
## rand & rand! & seed! docstrings
301303

stdlib/Random/src/Xoshiro.jl

+208
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
## Xoshiro RNG
4+
# Lots of implementation is shared with TaskLocalRNG
5+
6+
"""
7+
Xoshiro
8+
9+
Xoshiro256++ is a fast pseudorandom number generator originally developed by Sebastian Vigna.
10+
Reference implementation is available at http://prng.di.unimi.it
11+
12+
Apart from the high speed, Xoshiro has a small memory footprint, making it suitable for
13+
applications where many different random states need to be held for long time.
14+
15+
Julia's Xoshiro implementation has a bulk-generation mode; this seeds new virtual PRNGs
16+
from the parent, and uses SIMD to generate in parallel (i.e. the bulk stream consists of
17+
multiple interleaved xoshiro instances).
18+
The virtual PRNGs are discarded once the bulk request has been serviced (and should cause
19+
no heap allocations).
20+
"""
21+
mutable struct Xoshiro <: AbstractRNG
22+
s0::UInt64
23+
s1::UInt64
24+
s2::UInt64
25+
s3::UInt64
26+
end
27+
28+
Xoshiro(::Nothing) = Xoshiro()
29+
30+
function Xoshiro()
31+
parent = RandomDevice()
32+
# Constants have nothing up their sleeve, see task.c
33+
# 0x02011ce34bce797f == hash(UInt(1))|0x01
34+
# 0x5a94851fb48a6e05 == hash(UInt(2))|0x01
35+
# 0x3688cf5d48899fa7 == hash(UInt(3))|0x01
36+
# 0x867b4bb4c42e5661 == hash(UInt(4))|0x01
37+
38+
Xoshiro(0x02011ce34bce797f * rand(parent, UInt64),
39+
0x5a94851fb48a6e05 * rand(parent, UInt64),
40+
0x3688cf5d48899fa7 * rand(parent, UInt64),
41+
0x867b4bb4c42e5661 * rand(parent, UInt64))
42+
end
43+
44+
copy(rng::Xoshiro) = Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3)
45+
46+
function copy!(dst::Xoshiro, src::Xoshiro)
47+
dst.s0, dst.s1, dst.s2, dst.s3 = src.s0, src.s1, src.s2, src.s3
48+
dst
49+
end
50+
51+
function ==(a::Xoshiro, b::Xoshiro)
52+
a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3
53+
end
54+
55+
rng_native_52(::Xoshiro) = UInt64
56+
57+
function seed!(rng::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
58+
# see task.c
59+
s = Base.hash_uint64(s0)
60+
rng.s0 = s
61+
s += Base.hash_uint64(s1)
62+
rng.s1 = s
63+
s += Base.hash_uint64(s2)
64+
rng.s2 = s
65+
s += Base.hash_uint64(s3)
66+
rng.s3 = s
67+
rng
68+
end
69+
70+
@inline function rand(rng::Xoshiro, ::SamplerType{UInt64})
71+
s0, s1, s2, s3 = rng.s0, rng.s1, rng.s2, rng.s3
72+
tmp = s0 + s3
73+
res = tmp << 23 | tmp >> 41
74+
t = s1 << 17
75+
s2 = xor(s2, s0)
76+
s3 = xor(s3, s1)
77+
s1 = xor(s1, s2)
78+
s0 = xor(s0, s3)
79+
s2 = xor(s2, t)
80+
s3 = s3 << 45 | s3 >> 19
81+
rng.s0, rng.s1, rng.s2, rng.s3 = s0, s1, s2, s3
82+
res
83+
end
84+
85+
86+
## Task local RNG
87+
88+
"""
89+
TaskLocalRNG
90+
91+
The `TaskLocalRNG` has state that is local to its task, not its thread.
92+
It is seeded upon task creation, from the state of its parent task.
93+
Therefore, task creation is an event that changes the parent's RNG state.
94+
95+
As an upside, the `TaskLocalRNG` is pretty fast, and permits reproducible
96+
multithreaded simulations (barring race conditions), independent of scheduler
97+
decisions. As long as the number of threads is not used to make decisions on
98+
task creation, simulation results are also independent of the number of available
99+
threads / CPUs. The random stream should not depend on hardware specifics, up to
100+
endianness and possibly word size.
101+
102+
Using or seeding the RNG of any other task than the one returned by `current_task()`
103+
is undefined behavior: it will work most of the time, and may sometimes fail silently.
104+
"""
105+
struct TaskLocalRNG <: AbstractRNG end
106+
TaskLocalRNG(::Nothing) = TaskLocalRNG()
107+
rng_native_52(::TaskLocalRNG) = UInt64
108+
109+
function seed!(rng::TaskLocalRNG, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
110+
# TODO: Consider a less ad-hoc construction
111+
# We can afford burning a handful of cycles here, and we don't want any
112+
# surprises with respect to bad seeds / bad interactions.
113+
t = current_task()
114+
s = hash(s0)
115+
t.rngState0 = s
116+
s += hash(s1)
117+
t.rngState1 = s
118+
s += hash(s2)
119+
t.rngState2 = s
120+
s += hash(s3)
121+
t.rngState3 = s
122+
rand(rng, UInt64)
123+
rand(rng, UInt64)
124+
rand(rng, UInt64)
125+
rand(rng, UInt64)
126+
rng
127+
end
128+
129+
@inline function rand(::TaskLocalRNG, ::SamplerType{UInt64})
130+
task = current_task()
131+
s0, s1, s2, s3 = task.rngState0, task.rngState1, task.rngState2, task.rngState3
132+
tmp = s0 + s3
133+
res = tmp << 23 | tmp >> 41
134+
t = s1 << 17
135+
s2 = xor(s2, s0)
136+
s3 = xor(s3, s1)
137+
s1 = xor(s1, s2)
138+
s0 = xor(s0, s3)
139+
s2 = xor(s2, t)
140+
s3 = s3 << 45 | s3 >> 19
141+
task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
142+
res
143+
end
144+
145+
# Shared implementation between Xoshiro and TaskLocalRNG -- seeding
146+
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::UInt128)
147+
seed0 = seed % UInt64
148+
seed1 = (seed>>>64) % UInt64
149+
seed!(rng, seed0, seed1, zero(UInt64), zero(UInt64))
150+
end
151+
seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, UInt128(seed))
152+
153+
seed!(rng::Union{TaskLocalRNG, Xoshiro}) =
154+
seed!(rng, rand(RandomDevice(), UInt64), rand(RandomDevice(), UInt64),
155+
rand(RandomDevice(), UInt64), rand(RandomDevice(), UInt64))
156+
157+
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::AbstractVector{UInt64})
158+
if length(seed) > 4
159+
throw(ArgumentError("seed should have no more than 256 bits"))
160+
end
161+
seed0 = length(seed)>0 ? seed[1] : UInt64(0)
162+
seed1 = length(seed)>1 ? seed[2] : UInt64(0)
163+
seed2 = length(seed)>2 ? seed[3] : UInt64(0)
164+
seed3 = length(seed)>3 ? seed[4] : UInt64(0)
165+
seed!(rng, seed0, seed1, seed2, seed3)
166+
end
167+
168+
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::AbstractVector{UInt32})
169+
if iseven(length(seed))
170+
seed!(rng, reinterpret(UInt64, seed))
171+
else
172+
seed!(rng, UInt64[reinterpret(UInt64, @view(seed[begin:end-1])); seed[end] % UInt64])
173+
end
174+
end
175+
176+
@inline function rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt128})
177+
first = rand(rng, UInt64)
178+
second = rand(rng,UInt64)
179+
second + UInt128(first)<<64
180+
end
181+
182+
@inline rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{Int128}) = rand(rng, UInt128) % Int128
183+
184+
@inline rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{T}) where {T<:Union{Bool, UInt8, Int8, UInt16, Int16, UInt32, Int32, Int64}} = rand(rng, UInt64) % T
185+
186+
function copy(rng::TaskLocalRNG)
187+
t = current_task()
188+
Xoshiro(t.rngState0, t.rngState1, t.rngState2, t.rngState3)
189+
end
190+
191+
function copy!(dst::TaskLocalRNG, src::Xoshiro)
192+
t = current_task()
193+
t.rngState0, t.rngState1, t.rngState2, t.rngState3 = src.s0, src.s1, src.s2, src.s3
194+
dst
195+
end
196+
197+
function copy!(dst::Xoshiro, src::TaskLocalRNG)
198+
t = current_task()
199+
dst.s0, dst.s1, dst.s2, dst.s3 = t.rngState0, t.rngState1, t.rngState2, t.rngState3
200+
dst
201+
end
202+
203+
function ==(a::Xoshiro, b::TaskLocalRNG)
204+
t = current_task()
205+
a.s0 == t.rngState0 && a.s1 == t.rngState1 && a.s2 == t.rngState2 && a.s3 == t.rngState3
206+
end
207+
208+
==(a::TaskLocalRNG, b::Xoshiro) = b == a

0 commit comments

Comments
 (0)