Skip to content

Commit 5557c2f

Browse files
committed
try to keep source if it will be force-inlined
1 parent 34e8c07 commit 5557c2f

File tree

8 files changed

+99
-53
lines changed

8 files changed

+99
-53
lines changed

base/compiler/abstractinterpretation.jl

+9-7
Original file line numberDiff line numberDiff line change
@@ -532,8 +532,8 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
532532
mi === nothing && return nothing
533533
# try constant prop'
534534
inf_cache = get_inference_cache(interp)
535-
inf_result = cache_lookup(mi, argtypes, inf_cache)
536-
if inf_result === nothing
535+
cache = cache_lookup(mi, argtypes, inf_cache)
536+
if cache === nothing
537537
# if there might be a cycle, check to make sure we don't end up
538538
# calling ourselves here.
539539
let result = result # prevent capturing
@@ -552,8 +552,10 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
552552
frame = InferenceState(inf_result, #=cache=#false, interp)
553553
frame === nothing && return nothing # this is probably a bad generated function (unsound), but just ignore it
554554
frame.parent = sv
555-
push!(inf_cache, inf_result)
555+
push!(inf_cache, (inf_result, frame.stmt_info))
556556
typeinf(interp, frame) || return nothing
557+
else
558+
inf_result, _ = cache
557559
end
558560
result = inf_result.result
559561
# if constant inference hits a cycle, just bail out
@@ -590,7 +592,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
590592
return nothing
591593
end
592594
mi = mi::MethodInstance
593-
if !force && !const_prop_methodinstance_heuristic(interp, match, mi)
595+
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, sv)
594596
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
595597
return nothing
596598
end
@@ -692,7 +694,7 @@ end
692694
# This is a heuristic to avoid trying to const prop through complicated functions
693695
# where we would spend a lot of time, but are probably unlikely to get an improved
694696
# result anyway.
695-
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance)
697+
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance, sv::InferenceState)
696698
method = match.method
697699
if method.is_for_opaque_closure
698700
# Not inlining an opaque closure can be very expensive, so be generous
@@ -711,8 +713,8 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match:
711713
if isdefined(code, :inferred) && !cache_inlineable
712714
cache_inf = code.inferred
713715
if !(cache_inf === nothing)
714-
# TODO maybe we want to respect callsite `@inline`/`@noinline` annotations here ?
715-
cache_inlineable = inlining_policy(interp, cache_inf, 0x00, match) !== nothing
716+
src = inlining_policy(interp, cache_inf, get_curr_ssaflag(sv), nothing)
717+
cache_inlineable = src !== nothing
716718
end
717719
end
718720
if !cache_inlineable

base/compiler/inferenceresult.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,12 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing, va_override::
141141
return cache_argtypes, falses(length(cache_argtypes))
142142
end
143143

144-
function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cache::Vector{InferenceResult})
144+
function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cache::Vector{Tuple{InferenceResult,Vector{Any}}})
145145
method = linfo.def::Method
146146
nargs::Int = method.nargs
147147
method.isva && (nargs -= 1)
148148
length(given_argtypes) >= nargs || return nothing
149-
for cached_result in cache
149+
for (cached_result, stmt_info) in cache
150150
cached_result.linfo === linfo || continue
151151
cache_match = true
152152
cache_argtypes = cached_result.argtypes
@@ -165,7 +165,7 @@ function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cache:
165165
cache_overridden_by_const[end])
166166
end
167167
cache_match || continue
168-
return cached_result
168+
return cached_result, stmt_info
169169
end
170170
return nothing
171171
end

base/compiler/inferencestate.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ mutable struct InferenceState
113113
CachedMethodTable(method_table(interp)),
114114
interp)
115115
result.result = frame
116-
cached && push!(get_inference_cache(interp), result)
116+
cached && push!(get_inference_cache(interp), (result, stmt_info))
117117
return frame
118118
end
119119
end
@@ -296,3 +296,5 @@ function print_callstack(sv::InferenceState)
296296
sv = sv.parent
297297
end
298298
end
299+
300+
# Return the entry of `sv.src.ssaflags` at `sv`'s current program counter (`sv.currpc`),
# i.e. the SSA flag recorded for the statement that `sv` is currently positioned at.
get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc]

base/compiler/optimize.jl

+49-17
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,57 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T, I<:AbstractInterpreter
2828
interp::I
2929
end
3030

31-
function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_flag::UInt8, match::Union{MethodMatch,InferenceResult})
31+
include("compiler/ssair/driver.jl")
32+
33+
function inlining_policy(interp::AbstractInterpreter, @nospecialize(src),
34+
stmt_flag::UInt8, todo::Union{Nothing,InliningTodo})
3235
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
3336
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
3437
src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
3538
return src_inferred && src_inlineable ? src : nothing
3639
elseif isa(src, OptimizationState) && isdefined(src, :ir)
3740
return (is_stmt_inline(stmt_flag) || src.src.inlineable) ? src.ir : nothing
38-
elseif src === nothing && is_stmt_inline(stmt_flag) && isa(match, MethodMatch)
39-
# when the source isn't available at this moment, try to re-infer and inline it
40-
# NOTE we can make inference try to keep the source if the call is going to be inlined,
41-
# but then inlining will depend on local state of inference and so the first entry
42-
# and the succeeding ones may generate different code; rather we always re-infer
43-
# the source to avoid the problem while it's obviously not most efficient
44-
# HACK disable inlining for the re-inference to avoid cycles by making sure the following inference never comes here again
45-
interp = NativeInterpreter(get_world_counter(interp); opt_params = OptimizationParams(; inlining = false))
46-
src, rt = typeinf_code(interp, match.method, match.spec_types, match.sparams, true)
47-
return src
41+
elseif src === nothing && todo !== nothing && is_stmt_inline(stmt_flag)
42+
# if this statement is forced to be inlined, make an additional effort to find the source
43+
# in the local cache, and if found optimize and inline it
44+
mi = todo.mi
45+
(; match, atypes, stmttype) = todo.spec::DelayedInliningSpec
46+
if isa(match, MethodMatch)
47+
cache = cache_lookup(mi, atypes, get_inference_cache(interp))
48+
cache === nothing && return nothing
49+
inf_result, stmt_info = cache
50+
else
51+
local cache = nothing
52+
for (inf_result, stmt_info) in get_inference_cache(interp)
53+
if inf_result === match
54+
cache = inf_result, stmt_info
55+
break
56+
end
57+
end
58+
cache === nothing && return nothing
59+
inf_result, stmt_info = cache
60+
end
61+
src = inf_result.src
62+
if isa(src, CodeInfo)
63+
elseif isa(src, OptimizationState)
64+
src = src.src
65+
else
66+
return nothing
67+
end
68+
# HACK disable inlining for this optimization, otherwise we're likely to come back here again
69+
params = OptimizationParams(interp)
70+
newparams = OptimizationParams(; inlining = false,
71+
max_methods = params.MAX_METHODS,
72+
tuple_splat = params.MAX_TUPLE_SPLAT,
73+
union_splitting = params.MAX_UNION_SPLITTING,
74+
unoptimize_throw_blocks = params.unoptimize_throw_blocks)
75+
opt = OptimizationState(mi, copy(src), newparams, interp; stmt_info)
76+
optimize(interp, opt, newparams, stmttype)
77+
return opt.ir
4878
end
4979
return nothing
5080
end
5181

52-
include("compiler/ssair/driver.jl")
53-
5482
mutable struct OptimizationState
5583
linfo::MethodInstance
5684
src::CodeInfo
@@ -72,7 +100,8 @@ mutable struct OptimizationState
72100
frame.sptypes, frame.slottypes, false,
73101
inlining)
74102
end
75-
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
103+
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter;
104+
stmt_info::Union{Nothing,Vector{Any}} = nothing)
76105
# prepare src for running optimization passes
77106
# if it isn't already
78107
nssavalues = src.ssavaluetypes
@@ -86,7 +115,9 @@ mutable struct OptimizationState
86115
if slottypes === nothing
87116
slottypes = Any[ Any for i = 1:nslots ]
88117
end
89-
stmt_info = Any[nothing for i = 1:nssavalues]
118+
if stmt_info === nothing
119+
stmt_info = Any[nothing for i = 1:nssavalues]
120+
end
90121
# cache some useful state computations
91122
def = linfo.def
92123
mod = isa(def, Method) ? def.module : def
@@ -103,10 +134,11 @@ mutable struct OptimizationState
103134
end
104135
end
105136

106-
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
137+
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter;
138+
stmt_info::Union{Nothing,Vector{Any}} = nothing)
107139
src = retrieve_code_info(linfo)
108140
src === nothing && return nothing
109-
return OptimizationState(linfo, src, params, interp)
141+
return OptimizationState(linfo, src, params, interp; stmt_info)
110142
end
111143

112144
function ir_to_codeinf!(opt::OptimizationState)

base/compiler/ssair/inlining.jl

+7-6
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,8 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, (; linfo)::
722722
end
723723

724724
function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
725-
(; match) = todo.spec::DelayedInliningSpec
725+
mi = todo.mi
726+
(; match, atypes) = todo.spec::DelayedInliningSpec
726727

727728
#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
728729
isconst, src = false, nothing
@@ -737,7 +738,7 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
737738
isconst, src = false, inferred_src
738739
end
739740
else
740-
linfo = get(state.mi_cache, todo.mi, nothing)
741+
linfo = get(state.mi_cache, mi, nothing)
741742
if linfo isa CodeInstance
742743
if invoke_api(linfo) == 2
743744
# in this case function can be inlined to a constant
@@ -753,11 +754,11 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
753754
et = state.et
754755

755756
if isconst && et !== nothing
756-
push!(et, todo.mi)
757+
push!(et, mi)
757758
return ConstantCase(src)
758759
end
759760

760-
src = inlining_policy(state.interp, src, flag, match)
761+
src = inlining_policy(state.interp, src, flag, todo)
761762

762763
if src === nothing
763764
return compileable_specialization(et, match)
@@ -767,8 +768,8 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
767768
src = copy(src)
768769
end
769770

770-
et !== nothing && push!(et, todo.mi)
771-
return InliningTodo(todo.mi, src)
771+
et !== nothing && push!(et, mi)
772+
return InliningTodo(mi, src)
772773
end
773774

774775
function resolve_todo(todo::UnionSplit, state::InliningState, flag::UInt8)

base/compiler/typeinfer.jl

+13-13
Original file line numberDiff line numberDiff line change
@@ -438,34 +438,39 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
438438
# inspect whether our inference had a limited result accuracy,
439439
# else it may be suitable to cache
440440
me.bestguess = cycle_fix_limited(me.bestguess, me)
441+
parent = me.parent
441442
limited_ret = me.bestguess isa LimitedAccuracy
442443
limited_src = false
443444
if !limited_ret
444445
gt = me.src.ssavaluetypes::Vector{Any}
445446
for j = 1:length(gt)
446447
gt[j] = gtj = cycle_fix_limited(gt[j], me)
447-
if gtj isa LimitedAccuracy && me.parent !== nothing
448+
if gtj isa LimitedAccuracy && parent !== nothing
448449
limited_src = true
449450
break
450451
end
451452
end
452453
end
453454
if limited_ret
454455
# a parent may be cached still, but not this intermediate work:
455-
# we can throw everything else away now
456-
me.result.src = nothing
456+
# we can throw everything else away now, unless the inliner will still want to have the inferred source
457+
if !(parent !== nothing && is_stmt_inline(get_curr_ssaflag(parent)))
458+
me.result.src = nothing
459+
end
457460
me.cached = false
458461
me.src.inlineable = false
459462
unlock_mi_inference(interp, me.linfo)
460463
elseif limited_src
461464
# a type result will be cached still, but not this intermediate work:
462-
# we can throw everything else away now
463-
me.result.src = nothing
465+
# we can throw everything else away now, unless the inliner will still want to have the inferred source
466+
if !(parent !== nothing && is_stmt_inline(get_curr_ssaflag(parent)))
467+
me.result.src = nothing
468+
end
464469
me.src.inlineable = false
465470
else
466471
# annotate fulltree with type information,
467472
# either because we are the outermost code, or we might use this later
468-
doopt = (me.cached || me.parent !== nothing)
473+
doopt = (me.cached || parent !== nothing)
469474
type_annotate!(me, doopt)
470475
if doopt && may_optimize(interp)
471476
me.result.src = OptimizationState(me, OptimizationParams(interp), interp)
@@ -834,14 +839,9 @@ function typeinf_code(interp::AbstractInterpreter, method::Method, @nospecialize
834839
mi = specialize_method(method, atypes, sparams)::MethodInstance
835840
ccall(:jl_typeinf_begin, Cvoid, ())
836841
result = InferenceResult(mi)
837-
frame = InferenceState(result, false, interp)
842+
frame = InferenceState(result, run_optimizer, interp)
838843
frame === nothing && return (nothing, Any)
839-
if typeinf(interp, frame) && run_optimizer
840-
opt_params = OptimizationParams(interp)
841-
result.src = src = OptimizationState(frame, opt_params, interp)
842-
optimize(interp, src, opt_params, ignorelimited(result.result))
843-
frame.src = finish!(interp, result)
844-
end
844+
typeinf(interp, frame)
845845
ccall(:jl_typeinf_end, Cvoid, ())
846846
frame.inferred || return (nothing, Any)
847847
return (frame.src, widenconst(ignorelimited(result.result)))

base/compiler/types.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ It contains many parameters used by the compilation pipeline.
143143
"""
144144
struct NativeInterpreter <: AbstractInterpreter
145145
# Cache of inference results for this particular interpreter
146-
cache::Vector{InferenceResult}
146+
cache::Vector{Tuple{InferenceResult,Vector{Any}}}
147147
# The world age we're working inside of
148148
world::UInt
149149

@@ -168,7 +168,7 @@ struct NativeInterpreter <: AbstractInterpreter
168168

169169
return new(
170170
# Initially empty cache
171-
Vector{InferenceResult}(),
171+
Tuple{InferenceResult,Vector{Any}}[],
172172

173173
# world age counter
174174
world,

test/compiler/inline.jl

+13-4
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ function f_ifelse(x)
173173
return b ? x + 1 : x
174174
end
175175
# 2 for now because the compiler leaves a GotoNode around
176-
@test_broken length(code_typed(f_ifelse, (String,))[1][1].code) <= 2
176+
@test length(code_typed(f_ifelse, (String,))[1][1].code) <= 2
177177

178178
# Test that inlining of _apply_iterate properly hits the inference cache
179179
@noinline cprop_inline_foo1() = (1, 1)
@@ -543,9 +543,17 @@ end
543543

544544
import Core.Compiler: isType
545545

546-
limited(a) = @noinline(isType(a)) ? @inline(limited(a.parameters[1])) : rand(a)
546+
function limited(a)
547+
@nospecialize a
548+
if @noinline(isType(a))
549+
return @inline(limited(a.parameters[1]))
550+
else
551+
return rand(a)
552+
end
553+
end
547554

548555
function multilimited(a)
556+
@nospecialize a
549557
if @noinline(isType(a))
550558
return @inline(multilimited(a.parameters[1]))
551559
else
@@ -602,12 +610,13 @@ end
602610
end
603611

604612
let code = code_typed1(m.limited, (Any,))
605-
@test count(x->isinvoke(x, :isType), code) == 2
613+
@test count(x->isinvoke(x, :isType), code) == 2 # caller + inlined callee
606614
end
607-
# check that inlining for recursive callsites doesn't depend on inference local cache
608615
let code1 = code_typed1(m.multilimited, (Any,))
609616
code2 = code_typed1(m.multilimited, (Any,))
617+
# check that inlining for recursive callsites doesn't depend on inference local cache
610618
@test code1 == code2
619+
@test count(x->isinvoke(x, :isType), code1) == 3 # caller + inlined callee + inlined callee
611620
end
612621
end
613622

0 commit comments

Comments
 (0)