From b99ea9236585c8ff789492b785825002d056bd09 Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Mon, 24 Nov 2014 13:28:40 +0530 Subject: [PATCH] faster randn by separating out unlikely branch in a function All credits to @ViralBShah (cf. #8941 and #9126). This change probably allows better inlining. --- base/random.jl | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/base/random.jl b/base/random.jl index 44911d0b56f06..9d8c7ea0144e5 100644 --- a/base/random.jl +++ b/base/random.jl @@ -56,6 +56,7 @@ type Close1Open2 <: FloatInterval end @inline rand_ui52_raw_inbounds(r::MersenneTwister) = reinterpret(UInt64, rand_inbounds(r, Close1Open2)) @inline rand_ui52_raw(r::MersenneTwister) = (reserve_1(r); rand_ui52_raw_inbounds(r)) +@inline rand_ui52(r::MersenneTwister) = rand_ui52_raw(r) & 0x000fffffffffffff @inline rand_ui2x52_raw(r::MersenneTwister) = rand_ui52_raw(r) % UInt128 << 64 | rand_ui52_raw(r) function srand(r::MersenneTwister, seed::Vector{UInt32}) @@ -943,36 +944,37 @@ const ziggurat_nor_r = 3.6541528853610087963519472518 const ziggurat_nor_inv_r = inv(ziggurat_nor_r) const ziggurat_exp_r = 7.6971174701310497140446280481 -@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand(rng, Close1Open2)) & 0x000fffffffffffff function randmtzig_randn(rng::MersenneTwister=GLOBAL_RNG) @inbounds begin + r = rand_ui52(rng) + rabs = int64(r>>1) # One bit for the sign + idx = rabs & 0xFF + x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1] + rabs < ki[idx+1] && return x # 99.3% of the time we return here 1st try + return randmtzig_randn_unlikely(rng, idx, rabs, x) + end +end + +# this unlikely branch is put in a separate function for better efficiency +function randmtzig_randn_unlikely(rng, idx, rabs, x) + @inbounds if idx == 0 while true - r = randi(rng) - rabs = int64(r>>1) # One bit for the sign - idx = rabs & 0xFF - x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1] - if rabs < ki[idx+1] - return x # 99.3% of the time we return here 1st try - elseif idx == 0 - while true - xx = -ziggurat_nor_inv_r*log(rand(rng)) - yy = -log(rand(rng)) - if yy+yy > xx*xx - return (rabs & 0x100) != 0x000000000 ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx - end - end - elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x) - return x # return from the triangular area - end + xx = -ziggurat_nor_inv_r*log(rand(rng)) + yy = -log(rand(rng)) + yy+yy > xx*xx && return (rabs >> 8) % Bool ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx end + elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x) + return x # return from the triangular area + else + return randmtzig_randn(rng) end end function randmtzig_exprnd(rng::MersenneTwister=GLOBAL_RNG) @inbounds begin while true - ri = randi(rng) + ri = rand_ui52(rng) idx = ri & 0xFF x = ri*we[idx+1] if ri < ke[idx+1]