Commit f4eb5d9
Auto merge of #68828 - oli-obk:inline_cycle, r=wesleywiser

Prevent query cycles in the MIR inliner

r? `@eddyb` `@wesleywiser`
cc `@rust-lang/wg-mir-opt`

The general design is that we have a new query that is run on the `validated_mir` instead of on the `optimized_mir`. That query is forced before going into the optimization pipeline, so as to not try to read from a stolen MIR.

The query should not be cached cross crate, as you should never call it for items from other crates. By its very design, calls into other crates can never cause query cycles.

This is a pessimistic approach to inlining, since we strictly have more calls in the `validated_mir` than we have in `optimized_mir`, but that's not a problem imo.
2 parents: 7fba12b + d38553c
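For context on the cycle being prevented, here is a minimal, hypothetical sketch of user code (not part of this commit) with the kind of mutual recursion the new query has to detect. Inlining `pong` into `ping` needs `pong`'s optimized MIR, and optimizing `pong` in turn wants to inline `ping`, so consulting `optimized_mir` from inside the inliner could cycle; running the reachability check on `validated_mir` sidesteps that circular query dependency.

```rust
// Hypothetical example, not from this commit: mutually recursive functions.
fn ping(n: u32) -> u32 {
    if n == 0 { 0 } else { pong(n - 1) }
}

fn pong(n: u32) -> u32 {
    if n == 0 { 1 } else { ping(n - 1) }
}

fn main() {
    // ping(3) -> pong(2) -> ping(1) -> pong(0) == 1
    assert_eq!(ping(3), 1);
}
```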

15 files changed: +484 −18 lines

compiler/rustc_middle/src/query/mod.rs (+21)
```diff
@@ -782,6 +782,27 @@ rustc_queries! {
     }
 
     Other {
+        /// Check whether the function has any recursion that could cause the inliner to trigger
+        /// a cycle. Returns the call stack causing the cycle. The call stack does not contain the
+        /// current function, just all intermediate functions.
+        query mir_callgraph_reachable(key: (ty::Instance<'tcx>, LocalDefId)) -> bool {
+            fatal_cycle
+            desc { |tcx|
+                "computing if `{}` (transitively) calls `{}`",
+                key.0,
+                tcx.def_path_str(key.1.to_def_id()),
+            }
+        }
+
+        /// Obtain all the calls into other local functions
+        query mir_inliner_callees(key: ty::InstanceDef<'tcx>) -> &'tcx [(DefId, SubstsRef<'tcx>)] {
+            fatal_cycle
+            desc { |tcx|
+                "computing all local function calls in `{}`",
+                tcx.def_path_str(key.def_id()),
+            }
+        }
+
         /// Evaluates a constant and returns the computed allocation.
         ///
         /// **Do not use this** directly, use the `tcx.eval_static_initializer` wrapper.
```
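The `fatal_cycle` modifier matters because a query that (transitively) re-enters itself for the same key can never finish. As a rough illustration only, here is a toy query engine (entirely hypothetical, far simpler than rustc's real engine, which also handles caching, incremental dependencies, and cycle recovery) that detects such re-entry:

```rust
use std::collections::HashMap;

type Compute = fn(&mut Engine) -> Result<u64, String>;

// Toy model: a memoizing query engine that errors out on re-entrant queries.
#[derive(Default)]
struct Engine {
    cache: HashMap<&'static str, u64>,
    in_progress: Vec<&'static str>,
}

impl Engine {
    fn query(&mut self, key: &'static str, compute: Compute) -> Result<u64, String> {
        if let Some(&v) = self.cache.get(key) {
            return Ok(v); // cached result, no recomputation
        }
        if self.in_progress.contains(&key) {
            // The query depends on itself: report a cycle instead of looping.
            return Err(format!("cycle detected while computing `{}`", key));
        }
        self.in_progress.push(key);
        let v = compute(self)?;
        self.in_progress.pop();
        self.cache.insert(key, v);
        Ok(v)
    }
}

fn main() {
    let mut e = Engine::default();
    // A provider that re-enters its own query models the inliner asking for
    // the optimized MIR of a function it is currently optimizing.
    let r = e.query("optimized_mir(a)", |e| e.query("optimized_mir(a)", |_| Ok(0)));
    assert!(r.is_err());
}
```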

compiler/rustc_middle/src/ty/query/keys.rs (+11)
```diff
@@ -127,6 +127,17 @@ impl Key for (DefId, DefId) {
     }
 }
 
+impl Key for (ty::Instance<'tcx>, LocalDefId) {
+    type CacheSelector = DefaultCacheSelector;
+
+    fn query_crate(&self) -> CrateNum {
+        self.0.query_crate()
+    }
+    fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
+        self.0.default_span(tcx)
+    }
+}
+
 impl Key for (DefId, LocalDefId) {
     type CacheSelector = DefaultCacheSelector;
 
```
compiler/rustc_mir/src/lib.rs (+2)
```diff
@@ -57,6 +57,8 @@ pub fn provide(providers: &mut Providers) {
     providers.eval_to_const_value_raw = const_eval::eval_to_const_value_raw_provider;
     providers.eval_to_allocation_raw = const_eval::eval_to_allocation_raw_provider;
     providers.const_caller_location = const_eval::const_caller_location;
+    providers.mir_callgraph_reachable = transform::inline::cycle::mir_callgraph_reachable;
+    providers.mir_inliner_callees = transform::inline::cycle::mir_inliner_callees;
     providers.destructure_const = |tcx, param_env_and_value| {
         let (param_env, value) = param_env_and_value.into_parts();
         const_eval::destructure_const(tcx, param_env, value)
```
compiler/rustc_mir/src/transform/const_prop.rs (+9 −1)
```diff
@@ -440,7 +440,15 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
     }
 
     fn lint_root(&self, source_info: SourceInfo) -> Option<HirId> {
-        match &self.source_scopes[source_info.scope].local_data {
+        let mut data = &self.source_scopes[source_info.scope];
+        // FIXME(oli-obk): we should be able to just walk the `inlined_parent_scope`, but it
+        // does not work as I thought it would. Needs more investigation and documentation.
+        while data.inlined.is_some() {
+            trace!(?data);
+            data = &self.source_scopes[data.parent_scope.unwrap()];
+        }
+        trace!(?data);
+        match &data.local_data {
             ClearCrossCrate::Set(data) => Some(data.lint_root),
             ClearCrossCrate::Clear => None,
         }
```
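To make the scope walk concrete, here is a self-contained model of the new `lint_root` logic, with hypothetical stand-ins for rustc's `SourceScopeData`, scope indices, and `HirId`: scopes created by inlining carry `inlined: Some(..)` and no usable lint root, so we climb `parent_scope` until reaching a scope that originates in the caller itself.

```rust
// Toy stand-ins (names and types hypothetical) for rustc's scope data.
struct ScopeData {
    parent_scope: Option<usize>,   // index of the enclosing scope, if any
    inlined: Option<&'static str>, // Some(callee) if this scope came from inlining
    lint_root: u32,                // stand-in for the caller-side HirId
}

// Mirrors the new loop in `lint_root`: skip past scopes introduced by inlining.
fn lint_root(scopes: &[ScopeData], scope: usize) -> u32 {
    let mut data = &scopes[scope];
    while data.inlined.is_some() {
        data = &scopes[data.parent_scope.unwrap()];
    }
    data.lint_root
}

fn main() {
    let scopes = [
        ScopeData { parent_scope: None, inlined: None, lint_root: 1 },
        ScopeData { parent_scope: Some(0), inlined: Some("callee"), lint_root: 99 },
    ];
    // A statement in the inlined scope gets attributed to the caller's root.
    assert_eq!(lint_root(&scopes, 1), 1);
}
```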

compiler/rustc_mir/src/transform/inline.rs (+74 −17)
```diff
@@ -17,6 +17,8 @@ use crate::transform::MirPass;
 use std::iter;
 use std::ops::{Range, RangeFrom};
 
+crate mod cycle;
+
 const INSTR_COST: usize = 5;
 const CALL_PENALTY: usize = 25;
 const LANDINGPAD_PENALTY: usize = 50;
@@ -37,6 +39,9 @@ struct CallSite<'tcx> {
 
 impl<'tcx> MirPass<'tcx> for Inline {
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        // If you change this optimization level, also change the level in
+        // `mir_drops_elaborated_and_const_checked` for the call to `mir_inliner_callees`.
+        // Otherwise you will get an ICE about stolen MIR.
         if tcx.sess.opts.debugging_opts.mir_opt_level < 2 {
             return;
         }
@@ -50,6 +55,8 @@ impl<'tcx> MirPass<'tcx> for Inline {
             return;
         }
 
+        let span = trace_span!("inline", body = %tcx.def_path_str(body.source.def_id()));
+        let _guard = span.enter();
         if inline(tcx, body) {
             debug!("running simplify cfg on {:?}", body.source);
             CfgSimplifier::new(body).simplify();
@@ -90,8 +97,8 @@ struct Inliner<'tcx> {
     codegen_fn_attrs: &'tcx CodegenFnAttrs,
     /// Caller HirID.
     hir_id: hir::HirId,
-    /// Stack of inlined instances.
-    history: Vec<Instance<'tcx>>,
+    /// Stack of inlined Instances.
+    history: Vec<ty::Instance<'tcx>>,
     /// Indicates that the caller body has been modified.
     changed: bool,
 }
@@ -103,13 +110,28 @@ impl Inliner<'tcx> {
                 None => continue,
                 Some(it) => it,
             };
+            let span = trace_span!("process_blocks", %callsite.callee, ?bb);
+            let _guard = span.enter();
+
+            trace!(
+                "checking for self recursion ({:?} vs body_source: {:?})",
+                callsite.callee.def_id(),
+                caller_body.source.def_id()
+            );
+            if callsite.callee.def_id() == caller_body.source.def_id() {
+                debug!("Not inlining a function into itself");
+                continue;
+            }
 
-            if !self.is_mir_available(&callsite.callee, caller_body) {
+            if !self.is_mir_available(callsite.callee, caller_body) {
                 debug!("MIR unavailable {}", callsite.callee);
                 continue;
             }
 
+            let span = trace_span!("instance_mir", %callsite.callee);
+            let instance_mir_guard = span.enter();
             let callee_body = self.tcx.instance_mir(callsite.callee.def);
+            drop(instance_mir_guard);
             if !self.should_inline(callsite, callee_body) {
                 continue;
             }
@@ -137,28 +159,61 @@ impl Inliner<'tcx> {
         }
     }
 
-    fn is_mir_available(&self, callee: &Instance<'tcx>, caller_body: &Body<'tcx>) -> bool {
-        if let InstanceDef::Item(_) = callee.def {
-            if !self.tcx.is_mir_available(callee.def_id()) {
-                return false;
+    #[instrument(skip(self, caller_body))]
+    fn is_mir_available(&self, callee: Instance<'tcx>, caller_body: &Body<'tcx>) -> bool {
+        match callee.def {
+            InstanceDef::Item(_) => {
+                // If there is no MIR available (either because it was not in metadata or
+                // because it has no MIR because it's an extern function), then the inliner
+                // won't cause cycles on this.
+                if !self.tcx.is_mir_available(callee.def_id()) {
+                    return false;
+                }
             }
+            // These have no own callable MIR.
+            InstanceDef::Intrinsic(_) | InstanceDef::Virtual(..) => return false,
+            // This cannot result in an immediate cycle since the callee MIR is a shim, which does
+            // not get any optimizations run on it. Any subsequent inlining may cause cycles, but we
+            // do not need to catch this here, we can wait until the inliner decides to continue
+            // inlining a second time.
+            InstanceDef::VtableShim(_)
+            | InstanceDef::ReifyShim(_)
+            | InstanceDef::FnPtrShim(..)
+            | InstanceDef::ClosureOnceShim { .. }
+            | InstanceDef::DropGlue(..)
+            | InstanceDef::CloneShim(..) => return true,
+        }
+
+        if self.tcx.is_constructor(callee.def_id()) {
+            trace!("constructors always have MIR");
+            // Constructor functions cannot cause a query cycle.
+            return true;
         }
 
         if let Some(callee_def_id) = callee.def_id().as_local() {
             let callee_hir_id = self.tcx.hir().local_def_id_to_hir_id(callee_def_id);
-            // Avoid a cycle here by only using `instance_mir` only if we have
-            // a lower `HirId` than the callee. This ensures that the callee will
-            // not inline us. This trick only works without incremental compilation.
-            // So don't do it if that is enabled. Also avoid inlining into generators,
+            // Avoid inlining into generators,
             // since their `optimized_mir` is used for layout computation, which can
             // create a cycle, even when no attempt is made to inline the function
             // in the other direction.
-            !self.tcx.dep_graph.is_fully_enabled()
+            caller_body.generator_kind.is_none()
+                && (
+                    // Avoid a cycle here by only using `instance_mir` only if we have
+                    // a lower `HirId` than the callee. This ensures that the callee will
+                    // not inline us. This trick only works without incremental compilation.
+                    // So don't do it if that is enabled.
+                    !self.tcx.dep_graph.is_fully_enabled()
                 && self.hir_id < callee_hir_id
-                && caller_body.generator_kind.is_none()
+                    // If we know for sure that the function we're calling will itself try to
+                    // call us, then we avoid inlining that function.
+                    || !self.tcx.mir_callgraph_reachable((callee, caller_body.source.def_id().expect_local()))
+                )
         } else {
-            // This cannot result in a cycle since the callee MIR is from another crate
-            // and is already optimized.
+            // This cannot result in an immediate cycle since the callee MIR is from another crate
+            // and is already optimized. Any subsequent inlining may cause cycles, but we do
+            // not need to catch this here, we can wait until the inliner decides to continue
+            // inlining a second time.
+            trace!("functions from other crates always have MIR");
             true
         }
     }
@@ -203,8 +258,8 @@ impl Inliner<'tcx> {
         None
     }
 
+    #[instrument(skip(self, callee_body))]
     fn should_inline(&self, callsite: CallSite<'tcx>, callee_body: &Body<'tcx>) -> bool {
-        debug!("should_inline({:?})", callsite);
         let tcx = self.tcx;
 
         if callsite.fn_sig.c_variadic() {
@@ -333,7 +388,9 @@ impl Inliner<'tcx> {
                 if let Ok(Some(instance)) =
                     Instance::resolve(self.tcx, self.param_env, def_id, substs)
                 {
-                    if callsite.callee == instance || self.history.contains(&instance) {
+                    if callsite.callee.def_id() == instance.def_id()
+                        || self.history.contains(&instance)
+                    {
                         debug!("`callee is recursive - not inlining");
                         return false;
                     }
```
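The new early-out in `process_blocks` ("Not inlining a function into itself") handles direct recursion without consulting `mir_callgraph_reachable` at all. A hypothetical example of the shape it skips:

```rust
// Hypothetical user code, not from this commit: a directly recursive function.
// Its callsite back to itself has callee def-id == caller def-id, so the
// inliner skips it immediately.
fn fact(n: u64) -> u64 {
    if n == 0 { 1 } else { n * fact(n - 1) }
}

fn main() {
    assert_eq!(fact(5), 120);
}
```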
compiler/rustc_mir/src/transform/inline/cycle.rs (new file, +157)

```diff
@@ -0,0 +1,157 @@
+use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_data_structures::stack::ensure_sufficient_stack;
+use rustc_hir::def_id::{DefId, LocalDefId};
+use rustc_middle::mir::TerminatorKind;
+use rustc_middle::ty::TypeFoldable;
+use rustc_middle::ty::{self, subst::SubstsRef, InstanceDef, TyCtxt};
+
+// FIXME: check whether it is cheaper to precompute the entire call graph instead of invoking
+// this query riddiculously often.
+#[instrument(skip(tcx, root, target))]
+crate fn mir_callgraph_reachable(
+    tcx: TyCtxt<'tcx>,
+    (root, target): (ty::Instance<'tcx>, LocalDefId),
+) -> bool {
+    trace!(%root, target = %tcx.def_path_str(target.to_def_id()));
+    let param_env = tcx.param_env_reveal_all_normalized(target);
+    assert_ne!(
+        root.def_id().expect_local(),
+        target,
+        "you should not call `mir_callgraph_reachable` on immediate self recursion"
+    );
+    assert!(
+        matches!(root.def, InstanceDef::Item(_)),
+        "you should not call `mir_callgraph_reachable` on shims"
+    );
+    assert!(
+        !tcx.is_constructor(root.def_id()),
+        "you should not call `mir_callgraph_reachable` on enum/struct constructor functions"
+    );
+    #[instrument(skip(tcx, param_env, target, stack, seen, recursion_limiter, caller))]
+    fn process(
+        tcx: TyCtxt<'tcx>,
+        param_env: ty::ParamEnv<'tcx>,
+        caller: ty::Instance<'tcx>,
+        target: LocalDefId,
+        stack: &mut Vec<ty::Instance<'tcx>>,
+        seen: &mut FxHashSet<ty::Instance<'tcx>>,
+        recursion_limiter: &mut FxHashMap<DefId, usize>,
+    ) -> bool {
+        trace!(%caller);
+        for &(callee, substs) in tcx.mir_inliner_callees(caller.def) {
+            let substs = caller.subst_mir_and_normalize_erasing_regions(tcx, param_env, substs);
+            let callee = match ty::Instance::resolve(tcx, param_env, callee, substs).unwrap() {
+                Some(callee) => callee,
+                None => {
+                    trace!(?callee, "cannot resolve, skipping");
+                    continue;
+                }
+            };
+
+            // Found a path.
+            if callee.def_id() == target.to_def_id() {
+                return true;
+            }
+
+            if tcx.is_constructor(callee.def_id()) {
+                trace!("constructors always have MIR");
+                // Constructor functions cannot cause a query cycle.
+                continue;
+            }
+
+            match callee.def {
+                InstanceDef::Item(_) => {
+                    // If there is no MIR available (either because it was not in metadata or
+                    // because it has no MIR because it's an extern function), then the inliner
+                    // won't cause cycles on this.
+                    if !tcx.is_mir_available(callee.def_id()) {
+                        trace!(?callee, "no mir available, skipping");
+                        continue;
+                    }
+                }
+                // These have no own callable MIR.
+                InstanceDef::Intrinsic(_) | InstanceDef::Virtual(..) => continue,
+                // These have MIR and if that MIR is inlined, substituted and then inlining is run
+                // again, a function item can end up getting inlined. Thus we'll be able to cause
+                // a cycle that way
+                InstanceDef::VtableShim(_)
+                | InstanceDef::ReifyShim(_)
+                | InstanceDef::FnPtrShim(..)
+                | InstanceDef::ClosureOnceShim { .. }
+                | InstanceDef::CloneShim(..) => {}
+                InstanceDef::DropGlue(..) => {
+                    // FIXME: A not fully substituted drop shim can cause ICEs if one attempts to
+                    // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this
+                    // needs some more analysis.
+                    if callee.needs_subst() {
+                        continue;
+                    }
+                }
+            }
+
+            if seen.insert(callee) {
+                let recursion = recursion_limiter.entry(callee.def_id()).or_default();
+                trace!(?callee, recursion = *recursion);
+                if tcx.sess.recursion_limit().value_within_limit(*recursion) {
+                    *recursion += 1;
+                    stack.push(callee);
+                    let found_recursion = ensure_sufficient_stack(|| {
+                        process(tcx, param_env, callee, target, stack, seen, recursion_limiter)
+                    });
+                    if found_recursion {
+                        return true;
+                    }
+                    stack.pop();
+                } else {
+                    // Pessimistically assume that there could be recursion.
+                    return true;
+                }
+            }
+        }
+        false
+    }
+    process(
+        tcx,
+        param_env,
+        root,
+        target,
+        &mut Vec::new(),
+        &mut FxHashSet::default(),
+        &mut FxHashMap::default(),
+    )
+}
+
+crate fn mir_inliner_callees<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    instance: ty::InstanceDef<'tcx>,
+) -> &'tcx [(DefId, SubstsRef<'tcx>)] {
+    let steal;
+    let guard;
+    let body = match (instance, instance.def_id().as_local()) {
+        (InstanceDef::Item(_), Some(def_id)) => {
+            let def = ty::WithOptConstParam::unknown(def_id);
+            steal = tcx.mir_promoted(def).0;
+            guard = steal.borrow();
+            &*guard
+        }
+        // Functions from other crates and MIR shims
+        _ => tcx.instance_mir(instance),
+    };
+    let mut calls = Vec::new();
+    for bb_data in body.basic_blocks() {
+        let terminator = bb_data.terminator();
+        if let TerminatorKind::Call { func, .. } = &terminator.kind {
+            let ty = func.ty(&body.local_decls, tcx);
+            let call = match ty.kind() {
+                ty::FnDef(def_id, substs) => (*def_id, *substs),
+                _ => continue,
+            };
+            // We've seen this before
+            if calls.contains(&call) {
+                continue;
+            }
+            calls.push(call);
+        }
+    }
+    tcx.arena.alloc_slice(&calls)
+}
```
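As a standalone sketch of what `mir_callgraph_reachable` computes, under simplifying assumptions (functions are plain strings rather than `ty::Instance`s, and the limit is an arbitrary constant instead of `tcx.sess.recursion_limit()`): like the real query, hitting the limit pessimistically reports reachability rather than risk missing a cycle.

```rust
use std::collections::{HashMap, HashSet};

// Arbitrary stand-in for the compiler's configurable recursion limit.
const LIMIT: usize = 128;

// Depth-first search over a toy call graph: can `caller` transitively
// reach `target`?
fn reachable<'a>(
    graph: &HashMap<&'a str, Vec<&'a str>>,
    caller: &'a str,
    target: &'a str,
    seen: &mut HashSet<&'a str>,
    depth_per_fn: &mut HashMap<&'a str, usize>,
) -> bool {
    for &callee in graph.get(caller).into_iter().flatten() {
        if callee == target {
            return true; // found a path back to the target
        }
        // Only descend the first time we see this callee. The real query also
        // keeps a per-`DefId` counter, since distinct instances of the same
        // function would otherwise be revisited without bound.
        if seen.insert(callee) {
            let depth = depth_per_fn.entry(callee).or_default();
            if *depth < LIMIT {
                *depth += 1;
                if reachable(graph, callee, target, seen, depth_per_fn) {
                    return true;
                }
            } else {
                return true; // pessimistically assume recursion
            }
        }
    }
    false
}

fn main() {
    let graph: HashMap<_, _> =
        [("a", vec!["b"]), ("b", vec!["c"]), ("c", vec!["a"])].into_iter().collect();
    // `b` transitively calls `a`, so inlining `b` into `a` could cycle.
    assert!(reachable(&graph, "b", "a", &mut HashSet::new(), &mut HashMap::new()));
}
```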
