Skip to content

Commit 51ff820

Browse files
nhz2nalimilan
authored andcommitted
[Random] Add s4 field to Xoshiro type (#51332)
This PR adds an optional field to the existing `Xoshiro` struct to be able to faithfully copy the task-local RNG state. Fixes #51255 Redo of #51271 Background context: #49110 added an additional state to the task-local RNG. However, before this PR `copy(default_rng())` did not include this extra state, causing subtle errors in `Test` where `copy(default_rng())` is assumed to contain the full task-local RNG state. (cherry picked from commit 41b41ab)
1 parent a00e2d4 commit 51ff820

File tree

1 file changed

+33
-12
lines changed

1 file changed

+33
-12
lines changed

stdlib/Random/src/Xoshiro.jl

+33-12
Original file line numberDiff line numberDiff line change
@@ -48,28 +48,37 @@ mutable struct Xoshiro <: AbstractRNG
4848
s1::UInt64
4949
s2::UInt64
5050
s3::UInt64
51+
s4::UInt64 # internal splitmix state
5152

52-
Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer) = new(s0, s1, s2, s3)
53+
Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer, s4::Integer) = new(s0, s1, s2, s3, s4)
54+
Xoshiro(s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64) = new(s0, s1, s2, s3, 1s0 + 3s1 + 5s2 + 7s3)
5355
Xoshiro(seed=nothing) = seed!(new(), seed)
5456
end
5557

56-
function setstate!(x::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
58+
Xoshiro(s0::Integer, s1::Integer, s2::Integer, s3::Integer) = Xoshiro(UInt64(s0), UInt64(s1), UInt64(s2), UInt64(s3))
59+
60+
function setstate!(
61+
x::Xoshiro,
62+
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
63+
s4::UInt64, # internal splitmix state
64+
)
5765
x.s0 = s0
5866
x.s1 = s1
5967
x.s2 = s2
6068
x.s3 = s3
69+
x.s4 = s4
6170
x
6271
end
6372

64-
copy(rng::Xoshiro) = Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3)
73+
copy(rng::Xoshiro) = Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3, rng.s4)
6574

6675
function copy!(dst::Xoshiro, src::Xoshiro)
67-
dst.s0, dst.s1, dst.s2, dst.s3 = src.s0, src.s1, src.s2, src.s3
76+
dst.s0, dst.s1, dst.s2, dst.s3, dst.s4 = src.s0, src.s1, src.s2, src.s3, src.s4
6877
dst
6978
end
7079

7180
function ==(a::Xoshiro, b::Xoshiro)
72-
a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3
81+
a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3 && a.s4 == b.s4
7382
end
7483

7584
rng_native_52(::Xoshiro) = UInt64
@@ -116,7 +125,7 @@ rng_native_52(::TaskLocalRNG) = UInt64
116125
function setstate!(
117126
x::TaskLocalRNG,
118127
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
119-
s4::UInt64 = 1s0 + 3s1 + 5s2 + 7s3, # internal splitmix state
128+
s4::UInt64, # internal splitmix state
120129
)
121130
t = current_task()
122131
t.rngState0 = s0
@@ -148,14 +157,20 @@ end
148157
function seed!(rng::Union{TaskLocalRNG,Xoshiro})
149158
# as we get good randomness from RandomDevice, we can skip hashing
150159
rd = RandomDevice()
151-
setstate!(rng, rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64))
160+
s0 = rand(rd, UInt64)
161+
s1 = rand(rd, UInt64)
162+
s2 = rand(rd, UInt64)
163+
s3 = rand(rd, UInt64)
164+
s4 = 1s0 + 3s1 + 5s2 + 7s3
165+
setstate!(rng, s0, s1, s2, s3, s4)
152166
end
153167

154168
function seed!(rng::Union{TaskLocalRNG,Xoshiro}, seed::Union{Vector{UInt32}, Vector{UInt64}})
155169
c = SHA.SHA2_256_CTX()
156170
SHA.update!(c, reinterpret(UInt8, seed))
157171
s0, s1, s2, s3 = reinterpret(UInt64, SHA.digest!(c))
158-
setstate!(rng, s0, s1, s2, s3)
172+
s4 = 1s0 + 3s1 + 5s2 + 7s3
173+
setstate!(rng, s0, s1, s2, s3, s4)
159174
end
160175

161176
seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(seed))
@@ -178,24 +193,30 @@ end
178193

179194
function copy(rng::TaskLocalRNG)
180195
t = current_task()
181-
Xoshiro(t.rngState0, t.rngState1, t.rngState2, t.rngState3)
196+
Xoshiro(t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4)
182197
end
183198

184199
function copy!(dst::TaskLocalRNG, src::Xoshiro)
185200
t = current_task()
186-
setstate!(dst, src.s0, src.s1, src.s2, src.s3)
201+
setstate!(dst, src.s0, src.s1, src.s2, src.s3, src.s4)
187202
return dst
188203
end
189204

190205
function copy!(dst::Xoshiro, src::TaskLocalRNG)
191206
t = current_task()
192-
setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3)
207+
setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4)
193208
return dst
194209
end
195210

196211
function ==(a::Xoshiro, b::TaskLocalRNG)
197212
t = current_task()
198-
a.s0 == t.rngState0 && a.s1 == t.rngState1 && a.s2 == t.rngState2 && a.s3 == t.rngState3
213+
(
214+
a.s0 == t.rngState0 &&
215+
a.s1 == t.rngState1 &&
216+
a.s2 == t.rngState2 &&
217+
a.s3 == t.rngState3 &&
218+
a.s4 == t.rngState4
219+
)
199220
end
200221

201222
==(a::TaskLocalRNG, b::Xoshiro) = b == a

0 commit comments

Comments
 (0)