Skip to content

Commit d5c07f8

Browse files
Keno, simeonschaub
authored and committed
Propagate iteration info to optimizer (JuliaLang#36684)
This supersedes JuliaLang#36169. Rather than re-implementing the iteration analysis as done there, this uses the new stmtinfo infrastructure to propagate all the analysis done during inference all the way to inlining. As a result, it applies not only to splats of singletons, but also to splats of any other short iterable that inference can analyze. E.g.: ``` f(x) = (x...,) @code_typed f(1=>2) @benchmark f(1=>2) ``` Before: ``` julia> @code_typed f(1=>2) CodeInfo( 1 ─ %1 = Core._apply_iterate(Base.iterate, Core.tuple, x)::Tuple{Int64,Int64} └── return %1 ) => Tuple{Int64,Int64} julia> @benchmark f(1=>2) BenchmarkTools.Trial: memory estimate: 96 bytes allocs estimate: 3 -------------- minimum time: 242.659 ns (0.00% GC) median time: 246.904 ns (0.00% GC) mean time: 255.390 ns (1.08% GC) maximum time: 4.415 μs (93.94% GC) -------------- samples: 10000 evals/sample: 405 ``` After: ``` julia> @code_typed f(1=>2) CodeInfo( 1 ─ %1 = Base.getfield(x, 1)::Int64 │ %2 = Base.getfield(x, 2)::Int64 │ %3 = Core.tuple(%1, %2)::Tuple{Int64,Int64} └── return %3 ) => Tuple{Int64,Int64} julia> @benchmark f(1=>2) BenchmarkTools.Trial: memory estimate: 0 bytes allocs estimate: 0 -------------- minimum time: 1.701 ns (0.00% GC) median time: 1.925 ns (0.00% GC) mean time: 1.904 ns (0.00% GC) maximum time: 6.941 ns (0.00% GC) -------------- samples: 10000 evals/sample: 1000 ``` I also implemented the TODO I had left in JuliaLang#36169 to inline the iterate calls themselves, which gives another 3x improvement over the solution in that PR: ``` julia> @code_typed f(1) CodeInfo( 1 ─ %1 = Core.tuple(x)::Tuple{Int64} └── return %1 ) => Tuple{Int64} julia> @benchmark f(1) BenchmarkTools.Trial: memory estimate: 0 bytes allocs estimate: 0 -------------- minimum time: 1.696 ns (0.00% GC) median time: 1.699 ns (0.00% GC) mean time: 1.702 ns (0.00% GC) maximum time: 5.389 ns (0.00% GC) -------------- samples: 10000 evals/sample: 1000 ``` Fixes JuliaLang#36087 Fixes JuliaLang#29114
1 parent 869440d commit d5c07f8

File tree

6 files changed

+321
-179
lines changed

6 files changed

+321
-179
lines changed

base/compiler/abstractinterpretation.jl

+49-29
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
7878
push!(fullmatch, thisfullmatch)
7979
end
8080
end
81-
info = UnionSplitInfo(splitsigs, infos)
81+
info = UnionSplitInfo(infos)
8282
else
8383
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
8484
if mt === nothing
@@ -505,13 +505,13 @@ end
505505
# returns an array of types
506506
function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(typ), vtypes::VarTable, sv::InferenceState)
507507
if isa(typ, PartialStruct) && typ.typ.name === Tuple.name
508-
return typ.fields
508+
return typ.fields, nothing
509509
end
510510

511511
if isa(typ, Const)
512512
val = typ.val
513513
if isa(val, SimpleVector) || isa(val, Tuple)
514-
return Any[ Const(val[i]) for i in 1:length(val) ] # avoid making a tuple Generator here!
514+
return Any[ Const(val[i]) for i in 1:length(val) ], nothing # avoid making a tuple Generator here!
515515
end
516516
end
517517

@@ -529,27 +529,27 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
529529
if isa(tti, Union)
530530
utis = uniontypes(tti)
531531
if _any(t -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
532-
return Any[Vararg{Any}]
532+
return Any[Vararg{Any}], nothing
533533
end
534534
result = Any[rewrap_unionall(p, tti0) for p in utis[1].parameters]
535535
for t in utis[2:end]
536536
if length(t.parameters) != length(result)
537-
return Any[Vararg{Any}]
537+
return Any[Vararg{Any}], nothing
538538
end
539539
for j in 1:length(t.parameters)
540540
result[j] = tmerge(result[j], rewrap_unionall(t.parameters[j], tti0))
541541
end
542542
end
543-
return result
543+
return result, nothing
544544
elseif tti0 <: Tuple
545545
if isa(tti0, DataType)
546546
if isvatuple(tti0) && length(tti0.parameters) == 1
547-
return Any[Vararg{unwrapva(tti0.parameters[1])}]
547+
return Any[Vararg{unwrapva(tti0.parameters[1])}], nothing
548548
else
549-
return Any[ p for p in tti0.parameters ]
549+
return Any[ p for p in tti0.parameters ], nothing
550550
end
551551
elseif !isa(tti, DataType)
552-
return Any[Vararg{Any}]
552+
return Any[Vararg{Any}], nothing
553553
else
554554
len = length(tti.parameters)
555555
last = tti.parameters[len]
@@ -558,12 +558,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
558558
if va
559559
elts[len] = Vararg{elts[len]}
560560
end
561-
return elts
561+
return elts, nothing
562562
end
563563
elseif tti0 === SimpleVector || tti0 === Any
564-
return Any[Vararg{Any}]
564+
return Any[Vararg{Any}], nothing
565565
elseif tti0 <: Array
566-
return Any[Vararg{eltype(tti0)}]
566+
return Any[Vararg{eltype(tti0)}], nothing
567567
else
568568
return abstract_iteration(interp, itft, typ, vtypes, sv)
569569
end
@@ -572,30 +572,34 @@ end
572572
# simulate iteration protocol on container type up to fixpoint
573573
function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(itertype), vtypes::VarTable, sv::InferenceState)
574574
if !isdefined(Main, :Base) || !isdefined(Main.Base, :iterate) || !isconst(Main.Base, :iterate)
575-
return Any[Vararg{Any}]
575+
return Any[Vararg{Any}], nothing
576576
end
577577
if itft === nothing
578578
iteratef = getfield(Main.Base, :iterate)
579579
itft = Const(iteratef)
580580
elseif isa(itft, Const)
581581
iteratef = itft.val
582582
else
583-
return Any[Vararg{Any}]
583+
return Any[Vararg{Any}], nothing
584584
end
585585
@assert !isvarargtype(itertype)
586-
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], vtypes, sv).rt
586+
call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], vtypes, sv)
587+
stateordonet = call.rt
588+
info = call.info
587589
# Return Bottom if this is not an iterator.
588590
# WARNING: Changes to the iteration protocol must be reflected here,
589591
# this is not just an optimization.
590-
stateordonet === Bottom && return Any[Bottom]
592+
stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, info)])
591593
valtype = statetype = Bottom
592594
ret = Any[]
595+
calls = CallMeta[call]
596+
593597
# Try to unroll the iteration up to MAX_TUPLE_SPLAT, which covers any finite
594598
# length iterators, or interesting prefix
595599
while true
596600
stateordonet_widened = widenconst(stateordonet)
597601
if stateordonet_widened === Nothing
598-
return ret
602+
return ret, AbstractIterationInfo(calls)
599603
end
600604
if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).MAX_TUPLE_SPLAT
601605
break
@@ -607,12 +611,14 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
607611
# If there's no new information in this statetype, don't bother continuing,
608612
# the iterator won't be finite.
609613
if nstatetype ⊑ statetype
610-
return Any[Bottom]
614+
return Any[Bottom], nothing
611615
end
612616
valtype = getfield_tfunc(stateordonet, Const(1))
613617
push!(ret, valtype)
614618
statetype = nstatetype
615-
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv).rt
619+
call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
620+
stateordonet = call.rt
621+
push!(calls, call)
616622
end
617623
# From here on, we start asking for results on the widened types, rather than
618624
# the precise (potentially const) state type
@@ -629,15 +635,15 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
629635
if nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype
630636
if typeintersect(stateordonet, Nothing) === Union{}
631637
# Reached a fixpoint, but Nothing is not possible => iterator is infinite or failing
632-
return Any[Bottom]
638+
return Any[Bottom], nothing
633639
end
634640
break
635641
end
636642
valtype = tmerge(valtype, nounion.parameters[1])
637643
statetype = tmerge(statetype, nounion.parameters[2])
638644
end
639645
push!(ret, Vararg{valtype})
640-
return ret
646+
return ret, nothing
641647
end
642648

643649
# do apply(af, fargs...), where af is a function value
@@ -656,13 +662,15 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
656662
nargs = length(aargtypes)
657663
splitunions = 1 < countunionsplit(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM
658664
ctypes = Any[Any[aft]]
665+
infos = [Union{Nothing, AbstractIterationInfo}[]]
659666
for i = 1:nargs
660667
ctypes´ = []
668+
infos′ = []
661669
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
662670
if !isvarargtype(ti)
663-
cti = precise_container_type(interp, itft, ti, vtypes, sv)
671+
cti, info = precise_container_type(interp, itft, ti, vtypes, sv)
664672
else
665-
cti = precise_container_type(interp, itft, unwrapva(ti), vtypes, sv)
673+
cti, info = precise_container_type(interp, itft, unwrapva(ti), vtypes, sv)
666674
# We can't represent a repeating sequence of the same types,
667675
# so tmerge everything together to get one type that represents
668676
# everything.
@@ -678,19 +686,29 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
678686
if _any(t -> t === Bottom, cti)
679687
continue
680688
end
681-
for ct in ctypes
689+
for j = 1:length(ctypes)
690+
ct = ctypes[j]
682691
if isvarargtype(ct[end])
692+
# This is vararg, we're not gonna be able to do any inlining,
693+
# drop the info
694+
info = nothing
695+
683696
tail = tuple_tail_elem(unwrapva(ct[end]), cti)
684697
push!(ctypes´, push!(ct[1:(end - 1)], tail))
685698
else
686699
push!(ctypes´, append!(ct[:], cti))
687700
end
701+
push!(infos′, push!(copy(infos[j]), info))
688702
end
689703
end
690704
ctypes = ctypes´
705+
infos = infos′
691706
end
692-
local info = nothing
693-
for ct in ctypes
707+
retinfos = ApplyCallInfo[]
708+
retinfo = UnionSplitApplyCallInfo(retinfos)
709+
for i = 1:length(ctypes)
710+
ct = ctypes[i]
711+
arginfo = infos[i]
694712
lct = length(ct)
695713
# truncate argument list at the first Vararg
696714
for i = 1:lct-1
@@ -701,15 +719,17 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
701719
end
702720
end
703721
call = abstract_call(interp, nothing, ct, vtypes, sv, max_methods)
704-
info = call.info
722+
push!(retinfos, ApplyCallInfo(call.info, arginfo))
705723
res = tmerge(res, call.rt)
706724
if res === Any
725+
# No point carrying forward the info, we're not gonna inline it anyway
726+
retinfo = nothing
707727
break
708728
end
709729
end
710730
# TODO: Add a special info type to capture all the iteration info.
711731
# For now, only propagate info if we don't also union-split the iteration
712-
return CallMeta(res, length(ctypes) == 1 ? info : false)
732+
return CallMeta(res, retinfo)
713733
end
714734

715735
function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
@@ -779,7 +799,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
779799
end
780800
rt = builtin_tfunction(interp, f, argtypes[2:end], sv)
781801
if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] ⊑ Tuple
782-
cti = precise_container_type(interp, nothing, argtypes[2], vtypes, sv)
802+
cti, _ = precise_container_type(interp, nothing, argtypes[2], vtypes, sv)
783803
idx = argtypes[3].val
784804
if 1 <= idx <= length(cti)
785805
rt = unwrapva(cti[idx])

0 commit comments

Comments
 (0)