Skip to content

Commit a02fdc7

Browse files
committed
Merge pull request #9132 from JuliaLang/rf/randn-splitbranch
faster randn by separating out unlikely branch in a function
2 parents f06b4a4 + b99ea92 commit a02fdc7

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

base/random.jl

+21-19
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ type Close1Open2 <: FloatInterval end
5656

5757
@inline rand_ui52_raw_inbounds(r::MersenneTwister) = reinterpret(UInt64, rand_inbounds(r, Close1Open2))
5858
@inline rand_ui52_raw(r::MersenneTwister) = (reserve_1(r); rand_ui52_raw_inbounds(r))
59+
@inline rand_ui52(r::MersenneTwister) = rand_ui52_raw(r) & 0x000fffffffffffff
5960
@inline rand_ui2x52_raw(r::MersenneTwister) = rand_ui52_raw(r) % UInt128 << 64 | rand_ui52_raw(r)
6061

6162
function srand(r::MersenneTwister, seed::Vector{UInt32})
@@ -943,36 +944,37 @@ const ziggurat_nor_r = 3.6541528853610087963519472518
943944
const ziggurat_nor_inv_r = inv(ziggurat_nor_r)
944945
const ziggurat_exp_r = 7.6971174701310497140446280481
945946

946-
@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand(rng, Close1Open2)) & 0x000fffffffffffff
947947

948948
function randmtzig_randn(rng::MersenneTwister=GLOBAL_RNG)
949949
@inbounds begin
950+
r = rand_ui52(rng)
951+
rabs = int64(r>>1) # One bit for the sign
952+
idx = rabs & 0xFF
953+
x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1]
954+
rabs < ki[idx+1] && return x # 99.3% of the time we return here 1st try
955+
return randmtzig_randn_unlikely(rng, idx, rabs, x)
956+
end
957+
end
958+
959+
# this unlikely branch is put in a separate function for better efficiency
960+
function randmtzig_randn_unlikely(rng, idx, rabs, x)
961+
@inbounds if idx == 0
950962
while true
951-
r = randi(rng)
952-
rabs = int64(r>>1) # One bit for the sign
953-
idx = rabs & 0xFF
954-
x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1]
955-
if rabs < ki[idx+1]
956-
return x # 99.3% of the time we return here 1st try
957-
elseif idx == 0
958-
while true
959-
xx = -ziggurat_nor_inv_r*log(rand(rng))
960-
yy = -log(rand(rng))
961-
if yy+yy > xx*xx
962-
return (rabs & 0x100) != 0x000000000 ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx
963-
end
964-
end
965-
elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x)
966-
return x # return from the triangular area
967-
end
963+
xx = -ziggurat_nor_inv_r*log(rand(rng))
964+
yy = -log(rand(rng))
965+
yy+yy > xx*xx && return (rabs >> 8) % Bool ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx
968966
end
967+
elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x)
968+
return x # return from the triangular area
969+
else
970+
return randmtzig_randn(rng)
969971
end
970972
end
971973

972974
function randmtzig_exprnd(rng::MersenneTwister=GLOBAL_RNG)
973975
@inbounds begin
974976
while true
975-
ri = randi(rng)
977+
ri = rand_ui52(rng)
976978
idx = ri & 0xFF
977979
x = ri*we[idx+1]
978980
if ri < ke[idx+1]

0 commit comments

Comments
 (0)