Skip to content

Commit 9a3673d

Browse files
aviatesk and dghosef committed
optimizer: supports callsite annotations of inlining, fixes #18773
Enable `@inline`/`@noinline` annotations on function callsites. From #40754.

Now `@inline` and `@noinline` can be applied to a code block, and the compiler will then try to (not) inline calls within the block:

```julia
@inline f(...) # The compiler will try to inline `f`

@inline f(...) + g(...) # The compiler will try to inline `f`, `g` and `+`

@inline f(args...) = ... # Of course annotations on a definition are still allowed
```

Here are a couple of notes on how those callsite annotations work:

- a callsite annotation always takes precedence over the annotation applied to the definition of the called function, whichever of `@inline`/`@noinline` we use:

  ```julia
  @inline function explicit_inline(args...)
      # body
  end

  let
      @noinline explicit_inline(args...) # this call will not be inlined
  end
  ```

- when callsite annotations are nested, the innermost annotation has the precedence:

  ```julia
  @noinline let a0, b0 = ...
      a = @inline f(a0)  # the compiler will try to inline this call
      b = notinlined(b0) # the compiler will NOT try to inline this call
      return a, b
  end
  ```

They're both tested and included in the documentation.

Co-Authored-By: Joseph Tan <[email protected]>
1 parent ed4c44f commit 9a3673d

16 files changed

+299
-71
lines changed

base/compiler/abstractinterpretation.jl

+7-4
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
554554
return nothing
555555
end
556556
mi = mi::MethodInstance
557-
if !force && !const_prop_methodinstance_heuristic(interp, method, mi)
557+
if !force && !const_prop_methodinstance_heuristic(interp, match, mi)
558558
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
559559
return nothing
560560
end
@@ -656,7 +656,8 @@ end
656656
# This is a heuristic to avoid trying to const prop through complicated functions
657657
# where we would spend a lot of time, but are probably unlikely to get an improved
658658
# result anyway.
659-
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
659+
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance)
660+
method = match.method
660661
if method.is_for_opaque_closure
661662
# Not inlining an opaque closure can be very expensive, so be generous
662663
# with the const-prop-ability. It is quite possible that we can't infer
@@ -674,7 +675,8 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method
674675
if isdefined(code, :inferred) && !cache_inlineable
675676
cache_inf = code.inferred
676677
if !(cache_inf === nothing)
677-
cache_inlineable = inlining_policy(interp)(cache_inf) !== nothing
678+
# TODO maybe we want to respect callsite `@inline`/`@noinline` annotations here ?
679+
cache_inlineable = inlining_policy(interp)(cache_inf, nothing, match) !== nothing
678680
end
679681
end
680682
if !cache_inlineable
@@ -1821,7 +1823,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18211823
if isa(fname, SlotNumber)
18221824
changes = StateUpdate(fname, VarState(Any, false), changes, false)
18231825
end
1824-
elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect
1826+
elseif hd === :code_coverage_effect ||
1827+
(hd !== :boundscheck && hd !== nothing && is_meta_expr_head(hd)) # :boundscheck can be narrowed to Bool
18251828
# these do not generate code
18261829
else
18271830
t = abstract_eval_statement(interp, stmt, changes, frame)

base/compiler/optimize.jl

+44-6
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,20 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T, P}
2828
policy::P
2929
end
3030

31-
function default_inlining_policy(@nospecialize(src))
31+
function default_inlining_policy(@nospecialize(src), stmt_flag::Union{Nothing,UInt8}, match::Union{MethodMatch,InferenceResult})
3232
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
3333
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
34-
src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
34+
src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
3535
return src_inferred && src_inlineable ? src : nothing
36-
end
37-
if isa(src, OptimizationState) && isdefined(src, :ir)
38-
return src.src.inlineable ? src.ir : nothing
36+
elseif isa(src, OptimizationState) && isdefined(src, :ir)
37+
return (is_stmt_inline(stmt_flag) || src.src.inlineable) ? src.ir : nothing
38+
elseif src === nothing && is_stmt_inline(stmt_flag) && isa(match, MethodMatch)
39+
# when the source isn't available at this moment, try to re-infer and inline it
40+
# HACK in order to avoid cycles here, we disable inlining and make sure the following inference never comes here
41+
# TODO sort out `AbstractInterpreter` interface to handle this well, and also inference should try to keep the source if the statement will be inlined
42+
interp = NativeInterpreter(; opt_params = OptimizationParams(; inlining = false))
43+
src, rt = typeinf_code(interp, match.method, match.spec_types, match.sparams, true)
44+
return src
3945
end
4046
return nothing
4147
end
@@ -134,6 +140,10 @@ const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError
134140
# This statement was marked as @inbounds by the user. If replaced by inlining,
135141
# any contained boundschecks may be removed
136142
const IR_FLAG_INBOUNDS = 0x01
143+
# This statement was marked as @inline by the user
144+
const IR_FLAG_INLINE = 0x01 << 1
145+
# This statement was marked as @noinline by the user
146+
const IR_FLAG_NOINLINE = 0x01 << 2
137147
# This statement may be removed if its result is unused. In particular it must
138148
# thus be both pure and effect free.
139149
const IR_FLAG_EFFECT_FREE = 0x01 << 4
@@ -179,6 +189,11 @@ function isinlineable(m::Method, me::OptimizationState, params::OptimizationPara
179189
return inlineable
180190
end
181191

192+
is_stmt_inline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_INLINE != 0
193+
is_stmt_inline(::Nothing) = false
194+
is_stmt_noinline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_NOINLINE != 0
195+
is_stmt_noinline(::Nothing) = false # not used for now
196+
182197
# These affect control flow within the function (so may not be removed
183198
# if there is no usage within the function), but don't affect the purity
184199
# of the function as a whole.
@@ -366,6 +381,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
366381
renumber_ir_elements!(code, changemap, labelmap)
367382

368383
inbounds_depth = 0 # Number of stacked inbounds
384+
inline_flags = BitVector()
369385
meta = Any[]
370386
flags = fill(0x00, length(code))
371387
for i = 1:length(code)
@@ -380,16 +396,38 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
380396
inbounds_depth -= 1
381397
end
382398
stmt = nothing
399+
elseif isexpr(stmt, :inline)
400+
if stmt.args[1]::Bool
401+
push!(inline_flags, true)
402+
else
403+
pop!(inline_flags)
404+
end
405+
stmt = nothing
406+
elseif isexpr(stmt, :noinline)
407+
if stmt.args[1]::Bool
408+
push!(inline_flags, false)
409+
else
410+
pop!(inline_flags)
411+
end
412+
stmt = nothing
383413
else
384414
stmt = normalize(stmt, meta)
385415
end
386416
code[i] = stmt
387-
if !(stmt === nothing)
417+
if stmt !== nothing
388418
if inbounds_depth > 0
389419
flags[i] |= IR_FLAG_INBOUNDS
390420
end
421+
if !isempty(inline_flags)
422+
if last(inline_flags)
423+
flags[i] |= IR_FLAG_INLINE
424+
else
425+
flags[i] |= IR_FLAG_NOINLINE
426+
end
427+
end
391428
end
392429
end
430+
@assert isempty(inline_flags) "malformed meta flags"
393431
strip_trailing_junk!(ci, code, stmtinfo, flags)
394432
cfg = compute_basic_blocks(code)
395433
types = Any[]

base/compiler/ssair/inlining.jl

+31-38
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
600600
argexprs::Vector{Any}, atypes::Vector{Any}, arginfos::Vector{Any},
601601
arg_start::Int, istate::InliningState)
602602

603+
flag = ir.stmts[idx][:flag]
603604
new_argexprs = Any[argexprs[arg_start]]
604605
new_atypes = Any[atypes[arg_start]]
605606
# loop over original arguments and flatten any known iterators
@@ -655,8 +656,9 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
655656
info = call.info
656657
handled = false
657658
if isa(info, ConstCallInfo)
658-
if maybe_handle_const_call!(ir, state1.id, new_stmt, info, new_sig,
659-
call.rt, istate, false, todo)
659+
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
660+
ir, state1.id, new_stmt, info, new_sig,call.rt, istate, flag, false, todo)
661+
660662
handled = true
661663
else
662664
info = info.call
@@ -667,7 +669,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
667669
MethodMatchInfo[info] : info.matches
668670
# See if we can inline this call to `iterate`
669671
analyze_single_call!(ir, todo, state1.id, new_stmt,
670-
new_sig, call.rt, info, istate)
672+
new_sig, call.rt, info, istate, flag)
671673
end
672674
if i != length(thisarginfo.each)
673675
valT = getfield_tfunc(call.rt, Const(1))
@@ -716,16 +718,16 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, result::Inf
716718
return mi
717719
end
718720

719-
function resolve_todo(todo::InliningTodo, state::InliningState)
720-
spec = todo.spec::DelayedInliningSpec
721+
function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
722+
(; match) = todo.spec::DelayedInliningSpec
721723

722724
#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
723725
isconst, src = false, nothing
724-
if isa(spec.match, InferenceResult)
725-
let inferred_src = spec.match.src
726+
if isa(match, InferenceResult)
727+
let inferred_src = match.src
726728
if isa(inferred_src, Const)
727729
if !is_inlineable_constant(inferred_src.val)
728-
return compileable_specialization(state.et, spec.match)
730+
return compileable_specialization(state.et, match)
729731
end
730732
isconst, src = true, quoted(inferred_src.val)
731733
else
@@ -753,12 +755,10 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
753755
return ConstantCase(src)
754756
end
755757

756-
if src !== nothing
757-
src = state.policy(src)
758-
end
758+
src = state.policy(src, flag, match)
759759

760760
if src === nothing
761-
return compileable_specialization(et, spec.match)
761+
return compileable_specialization(et, match)
762762
end
763763

764764
if isa(src, IRCode)
@@ -769,17 +769,9 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
769769
return InliningTodo(todo.mi, src)
770770
end
771771

772-
function resolve_todo(todo::UnionSplit, state::InliningState)
772+
function resolve_todo(todo::UnionSplit, state::InliningState, flag::UInt8)
773773
UnionSplit(todo.fully_covered, todo.atype,
774-
Pair{Any,Any}[sig=>resolve_todo(item, state) for (sig, item) in todo.cases])
775-
end
776-
777-
function resolve_todo!(todo::Vector{Pair{Int, Any}}, state::InliningState)
778-
for i = 1:length(todo)
779-
idx, item = todo[i]
780-
todo[i] = idx=>resolve_todo(item, state)
781-
end
782-
todo
774+
Pair{Any,Any}[sig=>resolve_todo(item, state, flag) for (sig, item) in todo.cases])
783775
end
784776

785777
function validate_sparams(sparams::SimpleVector)
@@ -790,7 +782,7 @@ function validate_sparams(sparams::SimpleVector)
790782
end
791783

792784
function analyze_method!(match::MethodMatch, atypes::Vector{Any},
793-
state::InliningState, @nospecialize(stmttyp))
785+
state::InliningState, @nospecialize(stmttyp), flag::UInt8)
794786
method = match.method
795787
methsig = method.sig
796788

@@ -806,11 +798,9 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
806798
end
807799

808800
# Bail out if any static parameters are left as TypeVar
809-
ok = true
810801
validate_sparams(match.sparams) || return nothing
811802

812-
813-
if !state.params.inlining
803+
if !state.params.inlining || is_stmt_noinline(flag)
814804
return compileable_specialization(state.et, match)
815805
end
816806

@@ -824,7 +814,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
824814
# If we don't have caches here, delay resolving this MethodInstance
825815
# until the batch inlining step (or an external post-processing pass)
826816
state.mi_cache === nothing && return todo
827-
return resolve_todo(todo, state)
817+
return resolve_todo(todo, state, flag)
828818
end
829819

830820
function InliningTodo(mi::MethodInstance, ir::IRCode)
@@ -1050,7 +1040,7 @@ is_builtin(s::Signature) =
10501040
s.ft ⊑ Builtin
10511041

10521042
function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallInfo,
1053-
state::InliningState, todo::Vector{Pair{Int, Any}})
1043+
state::InliningState, todo::Vector{Pair{Int, Any}}, flag::UInt8)
10541044
stmt = ir.stmts[idx][:inst]
10551045
calltype = ir.stmts[idx][:type]
10561046

@@ -1064,7 +1054,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallIn
10641054
atypes = atypes[4:end]
10651055
pushfirst!(atypes, atype0)
10661056

1067-
result = analyze_method!(info.match, atypes, state, calltype)
1057+
result = analyze_method!(info.match, atypes, state, calltype, flag)
10681058
handle_single_case!(ir, stmt, idx, result, true, todo)
10691059
return nothing
10701060
end
@@ -1159,7 +1149,7 @@ end
11591149

11601150
function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
11611151
sig::Signature, @nospecialize(calltype), infos::Vector{MethodMatchInfo},
1162-
state::InliningState)
1152+
state::InliningState, flag::UInt8)
11631153
cases = Pair{Any, Any}[]
11641154
signature_union = Union{}
11651155
only_method = nothing # keep track of whether there is one matching method
@@ -1192,7 +1182,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
11921182
fully_covered = false
11931183
continue
11941184
end
1195-
case = analyze_method!(match, sig.atypes, state, calltype)
1185+
case = analyze_method!(match, sig.atypes, state, calltype, flag)
11961186
if case === nothing
11971187
fully_covered = false
11981188
continue
@@ -1219,7 +1209,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
12191209
match = meth[1]
12201210
end
12211211
fully_covered = true
1222-
case = analyze_method!(match, sig.atypes, state, calltype)
1212+
case = analyze_method!(match, sig.atypes, state, calltype, flag)
12231213
case === nothing && return
12241214
push!(cases, Pair{Any,Any}(match.spec_types, case))
12251215
end
@@ -1241,7 +1231,7 @@ end
12411231

12421232
function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
12431233
info::ConstCallInfo, sig::Signature, @nospecialize(calltype),
1244-
state::InliningState,
1234+
state::InliningState, flag::UInt8,
12451235
isinvoke::Bool, todo::Vector{Pair{Int, Any}})
12461236
# when multiple matches are found, bail out and later inliner will union-split this signature
12471237
# TODO effectively use multiple constant analysis results here
@@ -1253,7 +1243,7 @@ function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
12531243
validate_sparams(item.mi.sparam_vals) || return true
12541244
mthd_sig = item.mi.def.sig
12551245
mistypes = item.mi.specTypes
1256-
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1246+
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
12571247
if sig.atype <: mthd_sig
12581248
handle_single_case!(ir, stmt, idx, item, isinvoke, todo)
12591249
return true
@@ -1291,6 +1281,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
12911281
info = info.info
12921282
end
12931283

1284+
flag = ir.stmts[idx][:flag]
1285+
12941286
# Inference determined this couldn't be analyzed. Don't question it.
12951287
if info === false
12961288
continue
@@ -1300,23 +1292,24 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13001292
# it'll have performed a specialized analysis for just this case. Use its
13011293
# result.
13021294
if isa(info, ConstCallInfo)
1303-
if maybe_handle_const_call!(ir, idx, stmt, info, sig, calltype, state, sig.f === Core.invoke, todo)
1295+
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
1296+
ir, idx, stmt, info, sig, calltype, state, flag, sig.f === Core.invoke, todo)
13041297
continue
13051298
else
13061299
info = info.call
13071300
end
13081301
end
13091302

13101303
if isa(info, OpaqueClosureCallInfo)
1311-
result = analyze_method!(info.match, sig.atypes, state, calltype)
1304+
result = analyze_method!(info.match, sig.atypes, state, calltype, flag)
13121305
handle_single_case!(ir, stmt, idx, result, false, todo)
13131306
continue
13141307
end
13151308

13161309
# Handle invoke
13171310
if sig.f === Core.invoke
13181311
if isa(info, InvokeCallInfo)
1319-
inline_invoke!(ir, idx, sig, info, state, todo)
1312+
inline_invoke!(ir, idx, sig, info, state, todo, flag)
13201313
end
13211314
continue
13221315
end
@@ -1330,7 +1323,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13301323
continue
13311324
end
13321325

1333-
analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state)
1326+
analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state, flag)
13341327
end
13351328
todo
13361329
end

base/compiler/typeinfer.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
343343
nslots = length(ci.slotflags)
344344
resize!(ci.slottypes, nslots)
345345
resize!(ci.slotnames, nslots)
346-
return ccall(:jl_compress_ir, Any, (Any, Any), def, ci)
346+
return ccall(:jl_compress_ir, Vector{UInt8}, (Any, Any), def, ci)
347347
else
348348
return ci
349349
end

base/compiler/types.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ swapped in as long as they follow the AbstractInterpreter API.
1010
1111
All AbstractInterpreters are expected to provide at least the following methods:
1212
13-
- InferenceParams(interp) - return an `InferenceParams` instance
14-
- OptimizationParams(interp) - return an `OptimizationParams` instance
15-
- get_world_counter(interp) - return the world age for this interpreter
16-
- get_inference_cache(interp) - return the runtime inference cache
13+
- `InferenceParams(interp)` - return an `InferenceParams` instance
14+
- `OptimizationParams(interp)` - return an `OptimizationParams` instance
15+
- `get_world_counter(interp)` - return the world age for this interpreter
16+
- `get_inference_cache(interp)` - return the runtime inference cache
1717
"""
1818
abstract type AbstractInterpreter; end
1919

0 commit comments

Comments
 (0)