Skip to content

Commit 9c1b88a

Browse files
committedJan 29, 2018
Tune Stateful iterator
This attempts to address some of the performance regressions observed with the Stateful iterator #25763. It gets most of the way there, but unfortunately still ends up allocating the `Stateful` iterator object rather than propagating through the fields. Getting the rest of the way there will require some compiler tweaks.
1 parent 787550e commit 9c1b88a

File tree

2 files changed

+43
-31
lines changed

2 files changed

+43
-31
lines changed
 

‎base/iterators.jl

+40-28
Original file line numberDiff line numberDiff line change
@@ -1016,44 +1016,56 @@ mutable struct Stateful{T, VS}
10161016
# A bit awkward right now, but adapted to the new iteration protocol
10171017
nextvalstate::Union{VS, Nothing}
10181018
taken::Int
1019-
# Try to find an appropriate type for the (value, state tuple),
1020-
# by doing a recursive unrolling of the iteration protocol up to
1021-
# fixpoint.
1022-
function fixpoint_iter_type(itrT::Type, valT::Type, stateT::Type)
1023-
nextvalstate = Base._return_type(next, Tuple{itrT, stateT})
1024-
nextvalstate <: Tuple{Any, Any} || return Any
1025-
nextvalstate = Tuple{
1026-
typejoin(valT, fieldtype(nextvalstate, 1)),
1027-
typejoin(stateT, fieldtype(nextvalstate, 2))}
1028-
return (Tuple{valT, stateT} == nextvalstate ? nextvalstate :
1029-
fixpoint_iter_type(itrT,
1030-
fieldtype(nextvalstate, 1),
1031-
fieldtype(nextvalstate, 2)))
1032-
end
1033-
function Stateful(itr::T) where {T}
1019+
@inline function Stateful(itr::T) where {T}
10341020
state = start(itr)
10351021
VS = fixpoint_iter_type(T, Union{}, typeof(state))
1036-
vs = done(itr, state) ? nothing : next(itr, state)::VS
1037-
new{T, VS}(itr, vs, 0)
1022+
if done(itr, state)
1023+
new{T, VS}(itr, nothing, 0)
1024+
else
1025+
new{T, VS}(itr, next(itr, state)::VS, 0)
1026+
end
10381027
end
10391028
end
10401029

1030+
# Try to find an appropriate type for the (value, state tuple),
1031+
# by doing a recursive unrolling of the iteration protocol up to
1032+
# fixpoint.
1033+
function fixpoint_iter_type(itrT::Type, valT::Type, stateT::Type)
1034+
nextvalstate = Base._return_type(next, Tuple{itrT, stateT})
1035+
nextvalstate <: Tuple{Any, Any} || return Any
1036+
nextvalstate = Tuple{
1037+
typejoin(valT, fieldtype(nextvalstate, 1)),
1038+
typejoin(stateT, fieldtype(nextvalstate, 2))}
1039+
return (Tuple{valT, stateT} == nextvalstate ? nextvalstate :
1040+
fixpoint_iter_type(itrT,
1041+
fieldtype(nextvalstate, 1),
1042+
fieldtype(nextvalstate, 2)))
1043+
end
1044+
10411045
convert(::Type{Stateful}, itr) = Stateful(itr)
10421046

1043-
isempty(s::Stateful) = s.nextvalstate === nothing
1047+
@inline isempty(s::Stateful) = s.nextvalstate === nothing
10441048

1045-
function popfirst!(s::Stateful)
1046-
isempty(s) && throw(EOFError())
1047-
val, state = s.nextvalstate
1048-
s.nextvalstate = done(s.itr, state) ? nothing : next(s.itr, state)
1049-
s.taken += 1
1050-
val
1049+
@inline function popfirst!(s::Stateful)
1050+
vs = s.nextvalstate
1051+
if vs === nothing
1052+
throw(EOFError())
1053+
else
1054+
val, state = vs
1055+
if done(s.itr, state)
1056+
s.nextvalstate = nothing
1057+
else
1058+
s.nextvalstate = next(s.itr, state)
1059+
end
1060+
s.taken += 1
1061+
return val
1062+
end
10511063
end
10521064

1053-
peek(s::Stateful, sentinel=nothing) = s.nextvalstate !== nothing ? s.nextvalstate[1] : sentinel
1054-
start(s::Stateful) = nothing
1055-
next(s::Stateful, state) = popfirst!(s), nothing
1056-
done(s::Stateful, state) = isempty(s)
1065+
@inline peek(s::Stateful, sentinel=nothing) = s.nextvalstate !== nothing ? s.nextvalstate[1] : sentinel
1066+
@inline start(s::Stateful) = nothing
1067+
@inline next(s::Stateful, state) = popfirst!(s), nothing
1068+
@inline done(s::Stateful, state) = isempty(s)
10571069
IteratorSize(::Type{Stateful{VS,T}} where VS) where {T} =
10581070
isa(IteratorSize(T), SizeUnknown) ? SizeUnknown() : HasLength()
10591071
eltype(::Type{Stateful{VS, T}} where VS) where {T} = eltype(T)

‎test/iterators.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -495,15 +495,15 @@ end
495495
end
496496

497497
@testset "Iterators.Stateful" begin
498-
let a = Iterators.Stateful("abcdef")
498+
let a = @inferred(Iterators.Stateful("abcdef"))
499499
@test !isempty(a)
500500
@test popfirst!(a) == 'a'
501501
@test collect(Iterators.take(a, 3)) == ['b','c','d']
502502
@test collect(a) == ['e', 'f']
503503
end
504-
let a = Iterators.Stateful([1, 1, 1, 2, 3, 4])
504+
let a = @inferred(Iterators.Stateful([1, 1, 1, 2, 3, 4]))
505505
for x in a; x == 1 || break; end
506506
@test Base.peek(a) == 3
507507
@test sum(a) == 7
508508
end
509-
end
509+
end

0 commit comments

Comments
 (0)
Please sign in to comment.