Skip to content

Commit 08d223f

Browse files
committed
Propagate iteration info to optimizer
This supersedes #36169. Rather than re-implementing the iteration analysis as done there, this uses the new stmtinfo infrastructure to propagate all the analysis done during inference all the way to inlining. As a result, it applies not only to splats of singletons, but also to splats of any other short iterable that inference can analyze. E.g.: ``` f(x) = (x...,) @code_typed f(1=>2) @benchmark f(1=>2) ``` Before: ``` julia> @code_typed f(1=>2) CodeInfo( 1 ─ %1 = Core._apply_iterate(Base.iterate, Core.tuple, x)::Tuple{Int64,Int64} └── return %1 ) => Tuple{Int64,Int64} julia> @benchmark f(1=>2) BenchmarkTools.Trial: memory estimate: 96 bytes allocs estimate: 3 -------------- minimum time: 242.659 ns (0.00% GC) median time: 246.904 ns (0.00% GC) mean time: 255.390 ns (1.08% GC) maximum time: 4.415 μs (93.94% GC) -------------- samples: 10000 evals/sample: 405 ``` After: ``` julia> @code_typed f(1=>2) CodeInfo( 1 ─ %1 = Base.getfield(x, 1)::Int64 │ %2 = Base.getfield(x, 2)::Int64 │ %3 = Core.tuple(%1, %2)::Tuple{Int64,Int64} └── return %3 ) => Tuple{Int64,Int64} julia> @benchmark f(1=>2) BenchmarkTools.Trial: memory estimate: 0 bytes allocs estimate: 0 -------------- minimum time: 1.701 ns (0.00% GC) median time: 1.925 ns (0.00% GC) mean time: 1.904 ns (0.00% GC) maximum time: 6.941 ns (0.00% GC) -------------- samples: 10000 evals/sample: 1000 ``` I also implemented the TODO I had left in #36169 to inline the iterate calls themselves, which gives another 3x improvement over the solution in that PR: ``` julia> @code_typed f(1) CodeInfo( 1 ─ %1 = Core.tuple(x)::Tuple{Int64} └── return %1 ) => Tuple{Int64} julia> @benchmark f(1) BenchmarkTools.Trial: memory estimate: 0 bytes allocs estimate: 0 -------------- minimum time: 1.696 ns (0.00% GC) median time: 1.699 ns (0.00% GC) mean time: 1.702 ns (0.00% GC) maximum time: 5.389 ns (0.00% GC) -------------- samples: 10000 evals/sample: 1000 ``` Fixes #36087 Fixes #29114
1 parent 6f62363 commit 08d223f

File tree

5 files changed

+256
-132
lines changed

5 files changed

+256
-132
lines changed

base/compiler/abstractinterpretation.jl

+50-28
Original file line numberDiff line numberDiff line change
@@ -505,13 +505,13 @@ end
505505
# returns an array of types
506506
function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(typ), vtypes::VarTable, sv::InferenceState)
507507
if isa(typ, PartialStruct) && typ.typ.name === Tuple.name
508-
return typ.fields
508+
return typ.fields, nothing
509509
end
510510

511511
if isa(typ, Const)
512512
val = typ.val
513513
if isa(val, SimpleVector) || isa(val, Tuple)
514-
return Any[ Const(val[i]) for i in 1:length(val) ] # avoid making a tuple Generator here!
514+
return Any[ Const(val[i]) for i in 1:length(val) ], nothing # avoid making a tuple Generator here!
515515
end
516516
end
517517

@@ -529,27 +529,27 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
529529
if isa(tti, Union)
530530
utis = uniontypes(tti)
531531
if _any(t -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
532-
return Any[Vararg{Any}]
532+
return Any[Vararg{Any}], nothing
533533
end
534534
result = Any[rewrap_unionall(p, tti0) for p in utis[1].parameters]
535535
for t in utis[2:end]
536536
if length(t.parameters) != length(result)
537-
return Any[Vararg{Any}]
537+
return Any[Vararg{Any}], nothing
538538
end
539539
for j in 1:length(t.parameters)
540540
result[j] = tmerge(result[j], rewrap_unionall(t.parameters[j], tti0))
541541
end
542542
end
543-
return result
543+
return result, nothing
544544
elseif tti0 <: Tuple
545545
if isa(tti0, DataType)
546546
if isvatuple(tti0) && length(tti0.parameters) == 1
547-
return Any[Vararg{unwrapva(tti0.parameters[1])}]
547+
return Any[Vararg{unwrapva(tti0.parameters[1])}], nothing
548548
else
549-
return Any[ p for p in tti0.parameters ]
549+
return Any[ p for p in tti0.parameters ], nothing
550550
end
551551
elseif !isa(tti, DataType)
552-
return Any[Vararg{Any}]
552+
return Any[Vararg{Any}], nothing
553553
else
554554
len = length(tti.parameters)
555555
last = tti.parameters[len]
@@ -558,12 +558,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
558558
if va
559559
elts[len] = Vararg{elts[len]}
560560
end
561-
return elts
561+
return elts, nothing
562562
end
563563
elseif tti0 === SimpleVector || tti0 === Any
564-
return Any[Vararg{Any}]
564+
return Any[Vararg{Any}], nothing
565565
elseif tti0 <: Array
566-
return Any[Vararg{eltype(tti0)}]
566+
return Any[Vararg{eltype(tti0)}], nothing
567567
else
568568
return abstract_iteration(interp, itft, typ, vtypes, sv)
569569
end
@@ -572,30 +572,35 @@ end
572572
# simulate iteration protocol on container type up to fixpoint
573573
function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(itertype), vtypes::VarTable, sv::InferenceState)
574574
if !isdefined(Main, :Base) || !isdefined(Main.Base, :iterate) || !isconst(Main.Base, :iterate)
575-
return Any[Vararg{Any}]
575+
return Any[Vararg{Any}], nothing
576576
end
577577
if itft === nothing
578578
iteratef = getfield(Main.Base, :iterate)
579579
itft = Const(iteratef)
580580
elseif isa(itft, Const)
581581
iteratef = itft.val
582582
else
583-
return Any[Vararg{Any}]
583+
return Any[Vararg{Any}], nothing
584584
end
585585
@assert !isvarargtype(itertype)
586-
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], vtypes, sv).rt
586+
call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], vtypes, sv)
587+
stateordonet = call.rt
588+
info = call.info
587589
# Return Bottom if this is not an iterator.
588590
# WARNING: Changes to the iteration protocol must be reflected here,
589591
# this is not just an optimization.
590-
stateordonet === Bottom && return Any[Bottom]
592+
stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(Any[Bottom], Any[info])
591593
valtype = statetype = Bottom
592594
ret = Any[]
595+
states = Any[stateordonet]
596+
infos = Any[info]
597+
593598
# Try to unroll the iteration up to MAX_TUPLE_SPLAT, which covers any finite
594599
# length iterators, or interesting prefix
595600
while true
596601
stateordonet_widened = widenconst(stateordonet)
597602
if stateordonet_widened === Nothing
598-
return ret
603+
return ret, AbstractIterationInfo(states, infos)
599604
end
600605
if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).MAX_TUPLE_SPLAT
601606
break
@@ -607,12 +612,15 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
607612
# If there's no new information in this statetype, don't bother continuing,
608613
# the iterator won't be finite.
609614
if nstatetype ⊑ statetype
610-
return Any[Bottom]
615+
return Any[Bottom], nothing
611616
end
612617
valtype = getfield_tfunc(stateordonet, Const(1))
613618
push!(ret, valtype)
614619
statetype = nstatetype
615-
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv).rt
620+
call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
621+
stateordonet = call.rt
622+
push!(states, stateordonet)
623+
push!(infos, call.info)
616624
end
617625
# From here on, we start asking for results on the widened types, rather than
618626
# the precise (potentially const) state type
@@ -629,15 +637,15 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
629637
if nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype
630638
if typeintersect(stateordonet, Nothing) === Union{}
631639
# Reached a fixpoint, but Nothing is not possible => iterator is infinite or failing
632-
return Any[Bottom]
640+
return Any[Bottom], nothing
633641
end
634642
break
635643
end
636644
valtype = tmerge(valtype, nounion.parameters[1])
637645
statetype = tmerge(statetype, nounion.parameters[2])
638646
end
639647
push!(ret, Vararg{valtype})
640-
return ret
648+
return ret, nothing
641649
end
642650

643651
# do apply(af, fargs...), where af is a function value
@@ -656,13 +664,15 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
656664
nargs = length(aargtypes)
657665
splitunions = 1 < countunionsplit(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM
658666
ctypes = Any[Any[aft]]
667+
infos = [Union{Nothing, AbstractIterationInfo}[]]
659668
for i = 1:nargs
660669
ctypes´ = []
670+
infos′ = []
661671
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
662672
if !isvarargtype(ti)
663-
cti = precise_container_type(interp, itft, ti, vtypes, sv)
673+
cti, info = precise_container_type(interp, itft, ti, vtypes, sv)
664674
else
665-
cti = precise_container_type(interp, itft, unwrapva(ti), vtypes, sv)
675+
cti, info = precise_container_type(interp, itft, unwrapva(ti), vtypes, sv)
666676
# We can't represent a repeating sequence of the same types,
667677
# so tmerge everything together to get one type that represents
668678
# everything.
@@ -678,19 +688,29 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
678688
if _any(t -> t === Bottom, cti)
679689
continue
680690
end
681-
for ct in ctypes
691+
for j = 1:length(ctypes)
692+
ct = ctypes[j]
682693
if isvarargtype(ct[end])
694+
# This is vararg, we're not gonna be able to do any inlining,
695+
# drop the info
696+
info = nothing
697+
683698
tail = tuple_tail_elem(unwrapva(ct[end]), cti)
684699
push!(ctypes´, push!(ct[1:(end - 1)], tail))
685700
else
686701
push!(ctypes´, append!(ct[:], cti))
687702
end
703+
push!(infos′, push!(copy(infos[j]), info))
688704
end
689705
end
690706
ctypes = ctypes´
707+
infos = infos′
691708
end
692-
local info = nothing
693-
for ct in ctypes
709+
retinfos = ApplyCallInfo[]
710+
retinfo = UnionSplitApplyCallInfo(retinfos)
711+
for i = 1:length(ctypes)
712+
ct = ctypes[i]
713+
arginfo = infos[i]
694714
lct = length(ct)
695715
# truncate argument list at the first Vararg
696716
for i = 1:lct-1
@@ -701,15 +721,17 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
701721
end
702722
end
703723
call = abstract_call(interp, nothing, ct, vtypes, sv, max_methods)
704-
info = call.info
724+
push!(retinfos, ApplyCallInfo(call.info, arginfo))
705725
res = tmerge(res, call.rt)
706726
if res === Any
727+
# No point carrying forward the info, we're not gonna inline it anyway
728+
retinfo = nothing
707729
break
708730
end
709731
end
710732
# TODO: Add a special info type to capture all the iteration info.
711733
# For now, only propagate info if we don't also union-split the iteration
712-
return CallMeta(res, length(ctypes) == 1 ? info : false)
734+
return CallMeta(res, retinfo)
713735
end
714736

715737
function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
@@ -779,7 +801,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
779801
end
780802
rt = builtin_tfunction(interp, f, argtypes[2:end], sv)
781803
if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] ⊑ Tuple
782-
cti = precise_container_type(interp, nothing, argtypes[2], vtypes, sv)
804+
cti, _ = precise_container_type(interp, nothing, argtypes[2], vtypes, sv)
783805
idx = argtypes[3].val
784806
if 1 <= idx <= length(cti)
785807
rt = unwrapva(cti[idx])

0 commit comments

Comments
 (0)