Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

faster rand!(::MersenneTwister, ::UnitRange) #25047

Merged
merged 3 commits into from
Dec 14, 2017
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
and a dedicated method for rand(::SamplerRangeFast{UInt32})
This could make it slightly faster in some cases, by avoiding a branch,
but the main idea so to prepare the unification of code between
this Sampler and SamplerRangeInt
rfourquet committed Dec 12, 2017
commit e058cc0b29a8f0fa59fa2512a36f80b54ff273f9
66 changes: 35 additions & 31 deletions base/random/generation.jl
Original file line number Diff line number Diff line change
@@ -147,6 +147,30 @@ end
# 2) "Default" which tries to use as few entropy bits as possible, at the cost of a
# a bigger upfront price associated with the creation of the sampler

#### helper functions

function rand_lteq(r::AbstractRNG, S, u::U, mask::U) where U<:Integer
while true
x = rand(r, S) & mask
x <= u && return x
end
end

function rand_lteq(rng::AbstractRNG, S, u::T)::T where T
while true
x = rand(rng, S)
x <= u && return x
end
end

# helper function, to turn types to values, should be removed once we
# can do rand(Uniform(UInt))
rand(rng::AbstractRNG, ::Val{T}) where {T} = rand(rng, T)

uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128

#### Fast

struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
@@ -156,32 +180,23 @@ struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler
mask::U # mask generated values before threshold rejection
end

function SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Base.BitInteger64,Bool}
isempty(r) && throw(ArgumentError("range must be non-empty"))
m = last(r) % UInt64 - first(r) % UInt64
bw = (64 - leading_zeros(m)) % UInt # bit-width
mask = (1 % UInt64 << bw) - (1 % UInt64)
SamplerRangeFast{UInt64,T}(first(r), bw, m, mask)
end
SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
SamplerRangeFast(r, uint_sup(T))

function SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Int128,UInt128}
function SamplerRangeFast(r::AbstractUnitRange{T}, ::Type{U}) where {T,U}
isempty(r) && throw(ArgumentError("range must be non-empty"))
m = (last(r)-first(r)) % UInt128
bw = (128 - leading_zeros(m)) % UInt # bit-width
mask = (1 % UInt128 << bw) - (1 % UInt128)
SamplerRangeFast{UInt128,T}(first(r), bw, m, mask)
m = (last(r) - first(r)) % U
bw = (sizeof(U) << 3 - leading_zeros(m)) % UInt # bit-width
mask = (1 % U << bw) - (1 % U)
SamplerRangeFast{U,T}(first(r), bw, m, mask)
end

function rand_lteq(r::AbstractRNG, S, u::U, mask::U) where U<:Integer
while true
x = rand(r, S) & mask
x <= u && return x
end
function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt32,T}) where T
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
x = rand_lteq(rng, Val(UInt32), m, mask)
(x + a % UInt32) % T
end

# helper function, to turn types to values, should be removed once we can do rand(Uniform(UInt))
rand(rng::AbstractRNG, ::Val{T}) where {T} = rand(rng, T)

function rand(rng::AbstractRNG, sp::SamplerRangeFast{UInt64,T}) where T
a, bw, m, mask = sp.a, sp.bw, sp.m, sp.mask
x = bw <= 52 ? rand_lteq(rng, UInt52Raw(), m, mask) :
@@ -216,17 +231,13 @@ maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} =
div(sup, k + (k == 0))*k - one(k)


struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler
a::T # first element of the range
bw::Int # bit width
k::U # range length or zero for full range
u::U # rejection threshold
end

uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32
uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64
uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128

SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} =
SamplerRangeInt(r, uint_sup(T))
@@ -254,13 +265,6 @@ end
Sampler(::AbstractRNG, r::AbstractUnitRange{T},
::Repetition) where {T<:Union{Bool,BitInteger}} = SamplerRangeInt(r)

function rand_lteq(rng::AbstractRNG, S, u::T)::T where T
while true
x = rand(rng, S)
x <= u && return x
end
end

rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:Union{Bool,BitInteger}} =
(unsigned(sp.a) + rem_knuth(rand_lteq(rng, Val(UInt32), sp.u), sp.k)) % T