Skip to content

Commit 997f877

Browse files
committed
fix #42078, improve the idempotency of callsite inlining
After #41328, inference can observe statement flags and try to re-infer a discarded source if it's going to be inlined. The re-inferred source will only be cached into the inference-local cache, and won't be cached globally.
1 parent 3041991 commit 997f877

File tree

6 files changed

+105
-52
lines changed

6 files changed

+105
-52
lines changed

base/compiler/abstractinterpretation.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -546,10 +546,9 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
546546
end
547547
end
548548
inf_result = InferenceResult(mi, argtypes, va_override)
549-
frame = InferenceState(inf_result, #=cache=#false, interp)
549+
frame = InferenceState(inf_result, #=cache=#:local, interp)
550550
frame === nothing && return nothing # this is probably a bad generated function (unsound), but just ignore it
551551
frame.parent = sv
552-
push!(inf_cache, inf_result)
553552
typeinf(interp, frame) || return nothing
554553
end
555554
result = inf_result.result
@@ -592,7 +591,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
592591
return nothing
593592
end
594593
mi = mi::MethodInstance
595-
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, sv)
594+
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, argtypes, sv)
596595
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
597596
return nothing
598597
end
@@ -696,7 +695,9 @@ end
696695
# This is a heuristic to avoid trying to const prop through complicated functions
697696
# where we would spend a lot of time, but are probably unlikely to get an improved
698697
# result anyway.
699-
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance, sv::InferenceState)
698+
function const_prop_methodinstance_heuristic(
699+
interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance,
700+
argtypes::Vector{Any}, sv::InferenceState)
700701
method = match.method
701702
if method.is_for_opaque_closure
702703
# Not inlining an opaque closure can be very expensive, so be generous
@@ -715,7 +716,7 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match:
715716
if isdefined(code, :inferred) && !cache_inlineable
716717
cache_inf = code.inferred
717718
if !(cache_inf === nothing)
718-
src = inlining_policy(interp, cache_inf, get_curr_ssaflag(sv))
719+
src = inlining_policy(interp, cache_inf, get_curr_ssaflag(sv), mi, argtypes)
719720
cache_inlineable = src !== nothing
720721
end
721722
end

base/compiler/inferencestate.jl

+7-6
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ mutable struct InferenceState
5353

5454
# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
5555
function InferenceState(result::InferenceResult, src::CodeInfo,
56-
cached::Bool, interp::AbstractInterpreter)
56+
cache::Symbol, interp::AbstractInterpreter)
5757
(; def) = linfo = result.linfo
58-
code = src.code::Array{Any,1}
58+
code = src.code::Vector{Any}
5959

6060
sp = sptypes_from_meth_instance(linfo::MethodInstance)
6161

@@ -92,6 +92,7 @@ mutable struct InferenceState
9292
valid_worlds = WorldRange(src.min_world,
9393
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
9494

95+
@assert cache === :no || cache === :local || cache === :global
9596
frame = new(
9697
InferenceParams(interp), result, linfo,
9798
sp, slottypes, mod, 0,
@@ -103,11 +104,11 @@ mutable struct InferenceState
103104
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
104105
Vector{InferenceState}(), # callers_in_cycle
105106
#=parent=#nothing,
106-
cached, false, false,
107+
cache === :global, false, false,
107108
CachedMethodTable(method_table(interp)),
108109
interp)
109110
result.result = frame
110-
cached && push!(get_inference_cache(interp), result)
111+
cache !== :no && push!(get_inference_cache(interp), result)
111112
return frame
112113
end
113114
end
@@ -222,12 +223,12 @@ end
222223

223224
method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table
224225

225-
function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)
226+
function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
226227
# prepare an InferenceState object for inferring lambda
227228
src = retrieve_code_info(result.linfo)
228229
src === nothing && return nothing
229230
validate_code_in_debug_mode(result.linfo, src, "lowered")
230-
return InferenceState(result, src, cached, interp)
231+
return InferenceState(result, src, cache, interp)
231232
end
232233

233234
function sptypes_from_meth_instance(linfo::MethodInstance)

base/compiler/optimize.jl

+18-8
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,30 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T, I<:AbstractInterpreter
2828
interp::I
2929
end
3030

31-
function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_flag::UInt8)
31+
function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_flag::UInt8,
32+
mi::MethodInstance, argtypes::Vector{Any})
3233
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
3334
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
3435
src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
3536
return src_inferred && src_inlineable ? src : nothing
3637
elseif isa(src, OptimizationState) && isdefined(src, :ir)
3738
return (is_stmt_inline(stmt_flag) || src.src.inlineable) ? src.ir : nothing
38-
else
39-
# maybe we want to make inference keep the source in a local cache if a statement is going to inlined
40-
# and re-optimize it here with disabling further inlining to avoid infinite optimization loop
41-
# (we can even naively try to re-infer it entirely)
42-
# but it seems like that "single-level-inlining" is more trouble and complex than it's worth
43-
# see https://github.com/JuliaLang/julia/pull/41328/commits/0fc0f71a42b8c9d04b0dafabf3f1f17703abf2e7
44-
return nothing
39+
elseif src === nothing && is_stmt_inline(stmt_flag)
40+
# if this statement is forced to be inlined, make an additional effort to find the
41+
# inferred source in the local cache
42+
# we still won't find a source for recursive call because the "single-level" inlining
43+
# seems to be more trouble and complex than it's worth
44+
inf_result = cache_lookup(mi, argtypes, get_inference_cache(interp))
45+
inf_result === nothing && return nothing
46+
src = inf_result.src
47+
if isa(src, CodeInfo)
48+
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
49+
return src_inferred ? src : nothing
50+
elseif isa(src, OptimizationState)
51+
return isdefined(src, :ir) ? src.ir : nothing
52+
else
53+
return nothing
54+
end
4555
end
4656
end
4757

base/compiler/ssair/inlining.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ end
722722

723723
function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
724724
mi = todo.mi
725-
(; match) = todo.spec::DelayedInliningSpec
725+
(; match, atypes) = todo.spec::DelayedInliningSpec
726726

727727
#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
728728
isconst, src = false, nothing
@@ -757,7 +757,7 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
757757
return ConstantCase(src)
758758
end
759759

760-
src = inlining_policy(state.interp, src, flag)
760+
src = inlining_policy(state.interp, src, flag, mi, atypes)
761761

762762
if src === nothing
763763
return compileable_specialization(et, match)

base/compiler/typeinfer.jl

+34-30
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

33
# build (and start inferring) the inference frame for the top-level MethodInstance
4-
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cached::Bool)
5-
frame = InferenceState(result, cached, interp)
4+
function typeinf(interp::AbstractInterpreter, result::InferenceResult, cache::Symbol)
5+
frame = InferenceState(result, cache, interp)
66
frame === nothing && return false
7-
cached && lock_mi_inference(interp, result.linfo)
7+
cache === :global && lock_mi_inference(interp, result.linfo)
88
return typeinf(interp, frame)
99
end
1010

@@ -774,22 +774,30 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
774774
mi = specialize_method(method, atypes, sparams)::MethodInstance
775775
code = get(code_cache(interp), mi, nothing)
776776
if code isa CodeInstance # return existing rettype if the code is already inferred
777-
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
778-
rettype = code.rettype
779-
if isdefined(code, :rettype_const)
780-
rettype_const = code.rettype_const
781-
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
782-
return PartialStruct(rettype, rettype_const), mi
783-
elseif rettype <: Core.OpaqueClosure && isa(rettype_const, PartialOpaque)
784-
return rettype_const, mi
785-
elseif isa(rettype_const, InterConditional)
786-
return rettype_const, mi
777+
if code.inferred === nothing && is_stmt_inline(get_curr_ssaflag(caller))
778+
# we already inferred this edge previously and decided to discarded the inferred code
779+
# but the inlinear will request to use it, we re-infer it here and keep it around in the local cache
780+
cache = :local
781+
else
782+
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
783+
rettype = code.rettype
784+
if isdefined(code, :rettype_const)
785+
rettype_const = code.rettype_const
786+
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
787+
return PartialStruct(rettype, rettype_const), mi
788+
elseif rettype <: Core.OpaqueClosure && isa(rettype_const, PartialOpaque)
789+
return rettype_const, mi
790+
elseif isa(rettype_const, InterConditional)
791+
return rettype_const, mi
792+
else
793+
return Const(rettype_const), mi
794+
end
787795
else
788-
return Const(rettype_const), mi
796+
return rettype, mi
789797
end
790-
else
791-
return rettype, mi
792798
end
799+
else
800+
cache = :global # cache edge targets by default
793801
end
794802
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0
795803
return Any, nothing
@@ -805,7 +813,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
805813
# completely new
806814
lock_mi_inference(interp, mi)
807815
result = InferenceResult(mi)
808-
frame = InferenceState(result, #=cached=#true, interp) # always use the cache for edge targets
816+
frame = InferenceState(result, cache, interp) # always use the cache for edge targets
809817
if frame === nothing
810818
# can't get the source for this, so we know nothing
811819
unlock_mi_inference(interp, mi)
@@ -834,14 +842,10 @@ function typeinf_code(interp::AbstractInterpreter, method::Method, @nospecialize
834842
mi = specialize_method(method, atypes, sparams)::MethodInstance
835843
ccall(:jl_typeinf_begin, Cvoid, ())
836844
result = InferenceResult(mi)
837-
frame = InferenceState(result, false, interp)
845+
frame = InferenceState(result, :no, interp)
838846
frame === nothing && return (nothing, Any)
839-
if typeinf(interp, frame) && run_optimizer
840-
opt_params = OptimizationParams(interp)
841-
result.src = src = OptimizationState(frame, opt_params, interp)
842-
optimize(interp, src, opt_params, ignorelimited(result.result))
843-
frame.src = finish!(interp, result)
844-
end
847+
run_optimizer && (frame.cached = true)
848+
typeinf(interp, frame)
845849
ccall(:jl_typeinf_end, Cvoid, ())
846850
frame.inferred || return (nothing, Any)
847851
return (frame.src, widenconst(ignorelimited(result.result)))
@@ -898,7 +902,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
898902
return retrieve_code_info(mi)
899903
end
900904
lock_mi_inference(interp, mi)
901-
frame = InferenceState(InferenceResult(mi), #=cached=#true, interp)
905+
frame = InferenceState(InferenceResult(mi), #=cache=#:global, interp)
902906
frame === nothing && return nothing
903907
typeinf(interp, frame)
904908
ccall(:jl_typeinf_end, Cvoid, ())
@@ -921,11 +925,11 @@ function typeinf_type(interp::AbstractInterpreter, method::Method, @nospecialize
921925
return code.rettype
922926
end
923927
end
924-
frame = InferenceResult(mi)
925-
typeinf(interp, frame, true)
928+
result = InferenceResult(mi)
929+
typeinf(interp, result, :global)
926930
ccall(:jl_typeinf_end, Cvoid, ())
927-
frame.result isa InferenceState && return nothing
928-
return widenconst(ignorelimited(frame.result))
931+
result.result isa InferenceState && return nothing
932+
return widenconst(ignorelimited(result.result))
929933
end
930934

931935
# This is a bridge for the C code calling `jl_typeinf_func()`
@@ -941,7 +945,7 @@ function typeinf_ext_toplevel(interp::AbstractInterpreter, linfo::MethodInstance
941945
ccall(:jl_typeinf_begin, Cvoid, ())
942946
if !src.inferred
943947
result = InferenceResult(linfo)
944-
frame = InferenceState(result, src, #=cached=#true, interp)
948+
frame = InferenceState(result, src, #=cache=#:global, interp)
945949
typeinf(interp, frame)
946950
@assert frame.inferred # TODO: deal with this better
947951
src = frame.src

test/compiler/inline.jl

+38-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ function f_ifelse(x)
173173
return b ? x + 1 : x
174174
end
175175
# 2 for now because the compiler leaves a GotoNode around
176-
@test_broken length(code_typed(f_ifelse, (String,))[1][1].code) <= 2
176+
@test length(code_typed(f_ifelse, (String,))[1][1].code) <= 2
177177

178178
# Test that inlining of _apply_iterate properly hits the inference cache
179179
@noinline cprop_inline_foo1() = (1, 1)
@@ -614,3 +614,40 @@ end
614614
# Issue #41299 - inlining deletes error check in :>
615615
g41299(f::Tf, args::Vararg{Any,N}) where {Tf,N} = f(args...)
616616
@test_throws TypeError g41299(>:, 1, 2)
617+
618+
# https://github.com/JuliaLang/julia/issues/42078
619+
# idempotency of callsite inling
620+
function getsource(mi::Core.MethodInstance)
621+
cache = Core.Compiler.code_cache(Core.Compiler.NativeInterpreter())
622+
codeinf = Core.Compiler.get(cache, mi, nothing)
623+
return isnothing(codeinf) ? nothing : codeinf.inferred
624+
end
625+
@noinline f42078(a) = sum(sincos(a))
626+
let
627+
ninlined = let
628+
code = code_typed1((Int,)) do a
629+
@inline f42078(a)
630+
end
631+
@test all(x->!isinvoke(x, :f42078), code)
632+
length(code)
633+
end
634+
635+
mi = let
636+
specs = collect(only(methods(f42078)).specializations)
637+
specs[findfirst(!isnothing, specs)]::Core.MethodInstance
638+
end
639+
640+
let # codegen will discard the source because it's not supposed to be inlined in general context
641+
a = 42
642+
f42078(a)
643+
end
644+
@assert getsource(mi) === nothing
645+
646+
let # inference should re-infer `f42078(::Int)` and we should get the same code
647+
code = code_typed1((Int,)) do a
648+
@inline f42078(a)
649+
end
650+
@test all(x->!isinvoke(x, :f42078), code)
651+
@test ninlined == length(code)
652+
end
653+
end

0 commit comments

Comments
 (0)