Skip to content

Commit 3824aa9

Browse files
committed
Auto merge of rust-lang#121557 - RalfJung:const-fn-call-promotion, r=<try>
restrict promotion of `const fn` calls We only promote them in `const`/`static` initializers, but even that is still unfortunate -- we still cannot add promoteds to required_consts. But we should add them there to make sure it's always okay to evaluate every const we encounter in a MIR body. That effort of not promoting things that can fail to evaluate is tracked in rust-lang#80619. These `const fn` calls are the last missing piece. So I propose that we do not promote const-fn calls in const when that may fail without the entire const failing, thereby completing rust-lang#80619. Unfortunately we can't just reject promoting these functions outright due to backwards compatibility. So let's see if we can find a hack that makes crater happy... For the record, this is the [crater analysis](rust-lang#80243 (comment)) from when I tried to entirely forbid this kind of promotion. It's a tiny amount of breakage and if we had a nice alternative for code like that, we could conceivably push it through... but sadly, inline const expressions are still blocked on t-lang concerns about post-monomorphization errors and we haven't yet figured out an implementation that can resolve those concerns. So we're forced to make progress via other means, such as terrible hacks like this. Attempt one: only promote calls on the "safe path" at the beginning of a MIR block. This is the path that starts at the start block and continues via gotos and calls, but stops at the first branch. If we had imposed this restriction before stabilizing `if` and `match` in `const`, this would have definitely been sufficient... EDIT: Turns out that works. :) **Here's the t-lang [nomination comment](rust-lang#121557 (comment) r? `@oli-obk`
2 parents c2fbe40 + 8a1214a commit 3824aa9

17 files changed

+235
-285
lines changed

compiler/rustc_const_eval/src/const_eval/machine.rs

+14
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,20 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for CompileTimeInterpreter<'mir,
375375

376376
const PANIC_ON_ALLOC_FAIL: bool = false; // will be raised as a proper error
377377

378+
#[inline]
379+
fn all_required_consts_are_checked(ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
380+
// Generally we expect required_consts to be properly filled, except for promoteds where
381+
// storing these consts shows up negatively in benchmarks. A promoted can only be relevant
382+
// if its parent MIR is relevant, and the consts in the promoted will be in the parent's
383+
// `required_consts`, so we are still sure to catch any const-eval bugs, just a bit less
384+
// directly.
385+
if ecx.frame_idx() == 0 && ecx.frame().body.source.promoted.is_some() {
386+
false
387+
} else {
388+
true
389+
}
390+
}
391+
378392
#[inline(always)]
379393
fn enforce_alignment(ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
380394
matches!(ecx.machine.check_alignment, CheckAlignment::Error)

compiler/rustc_const_eval/src/interpret/eval_context.rs

+13-11
Original file line numberDiff line numberDiff line change
@@ -796,15 +796,13 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
796796
self.stack_mut().push(frame);
797797

798798
// Make sure all the constants required by this frame evaluate successfully (post-monomorphization check).
799-
if M::POST_MONO_CHECKS {
800-
for &const_ in &body.required_consts {
801-
let c = self
802-
.instantiate_from_current_frame_and_normalize_erasing_regions(const_.const_)?;
803-
c.eval(*self.tcx, self.param_env, Some(const_.span)).map_err(|err| {
804-
err.emit_note(*self.tcx);
805-
err
806-
})?;
807-
}
799+
for &const_ in &body.required_consts {
800+
let c =
801+
self.instantiate_from_current_frame_and_normalize_erasing_regions(const_.const_)?;
802+
c.eval(*self.tcx, self.param_env, Some(const_.span)).map_err(|err| {
803+
err.emit_note(*self.tcx);
804+
err
805+
})?;
808806
}
809807

810808
// done
@@ -1153,8 +1151,12 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
11531151
) -> InterpResult<'tcx, OpTy<'tcx, M::Provenance>> {
11541152
M::eval_mir_constant(self, *val, span, layout, |ecx, val, span, layout| {
11551153
let const_val = val.eval(*ecx.tcx, ecx.param_env, span).map_err(|err| {
1156-
// FIXME: somehow this is reachable even when POST_MONO_CHECKS is on.
1157-
// Are we not always populating `required_consts`?
1154+
if M::all_required_consts_are_checked(self)
1155+
&& !matches!(err, ErrorHandled::TooGeneric(..))
1156+
{
1157+
// Looks like the const is not captued by `required_consts`, that's bad.
1158+
bug!("interpret const eval failure of {val:?} which is not in required_consts");
1159+
}
11581160
err.emit_note(*ecx.tcx);
11591161
err
11601162
})?;

compiler/rustc_const_eval/src/interpret/machine.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized {
140140
/// Should the machine panic on allocation failures?
141141
const PANIC_ON_ALLOC_FAIL: bool;
142142

143-
/// Should post-monomorphization checks be run when a stack frame is pushed?
144-
const POST_MONO_CHECKS: bool = true;
143+
/// Determines whether `eval_mir_constant` can never fail because all required consts have
144+
/// already been checked before.
145+
fn all_required_consts_are_checked(ecx: &InterpCx<'mir, 'tcx, Self>) -> bool;
145146

146147
/// Whether memory accesses should be alignment-checked.
147148
fn enforce_alignment(ecx: &InterpCx<'mir, 'tcx, Self>) -> bool;

compiler/rustc_mir_transform/src/dataflow_const_prop.rs

+6
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,12 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
900900
type MemoryKind = !;
901901
const PANIC_ON_ALLOC_FAIL: bool = true;
902902

903+
#[inline(always)]
904+
fn all_required_consts_are_checked(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
905+
// We want to just eval random consts in the program, so `eval_mir_const` can fail.
906+
false
907+
}
908+
903909
#[inline(always)]
904910
fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
905911
false // no reason to enforce alignment

compiler/rustc_mir_transform/src/inline.rs

+1-11
Original file line numberDiff line numberDiff line change
@@ -706,17 +706,7 @@ impl<'tcx> Inliner<'tcx> {
706706
});
707707

708708
// Copy only unevaluated constants from the callee_body into the caller_body.
709-
// Although we are only pushing `ConstKind::Unevaluated` consts to
710-
// `required_consts`, here we may not only have `ConstKind::Unevaluated`
711-
// because we are calling `instantiate_and_normalize_erasing_regions`.
712-
caller_body.required_consts.extend(callee_body.required_consts.iter().copied().filter(
713-
|&ct| match ct.const_ {
714-
Const::Ty(_) => {
715-
bug!("should never encounter ty::UnevaluatedConst in `required_consts`")
716-
}
717-
Const::Val(..) | Const::Unevaluated(..) => true,
718-
},
719-
));
709+
caller_body.required_consts.extend(callee_body.required_consts);
720710
}
721711

722712
fn make_call_args(

compiler/rustc_mir_transform/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ fn mir_promoted(
343343
body.tainted_by_errors = Some(error_reported);
344344
}
345345

346+
// Collect `required_consts` *before* promotion, so if there are any consts being promoted
347+
// we still add them to the list in the outer MIR body.
346348
let mut required_consts = Vec::new();
347349
let mut required_consts_visitor = RequiredConstsVisitor::new(&mut required_consts);
348350
for (bb, bb_data) in traversal::reverse_postorder(&body) {

compiler/rustc_mir_transform/src/promote_consts.rs

+113-34
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//! move analysis runs after promotion on broken MIR.
1414
1515
use either::{Left, Right};
16+
use rustc_data_structures::fx::FxHashSet;
1617
use rustc_hir as hir;
1718
use rustc_middle::mir;
1819
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
@@ -175,6 +176,12 @@ fn collect_temps_and_candidates<'tcx>(
175176
struct Validator<'a, 'tcx> {
176177
ccx: &'a ConstCx<'a, 'tcx>,
177178
temps: &'a mut IndexSlice<Local, TempState>,
179+
/// For backwards compatibility, we are promoting function calls in `const`/`static`
180+
/// initializers. But we want to avoid evaluating code that might panic and that otherwise would
181+
/// not have been evaluated, so we only promote such calls in basic blocks that are guaranteed
182+
/// to execute. In other words, we only promote such calls in basic blocks that are definitely
183+
/// not dead code. Here we cache the result of computing that set of basic blocks.
184+
promotion_safe_blocks: Option<FxHashSet<BasicBlock>>,
178185
}
179186

180187
impl<'a, 'tcx> std::ops::Deref for Validator<'a, 'tcx> {
@@ -260,7 +267,9 @@ impl<'tcx> Validator<'_, 'tcx> {
260267
self.validate_rvalue(rhs)
261268
}
262269
Right(terminator) => match &terminator.kind {
263-
TerminatorKind::Call { func, args, .. } => self.validate_call(func, args),
270+
TerminatorKind::Call { func, args, .. } => {
271+
self.validate_call(func, args, loc.block)
272+
}
264273
TerminatorKind::Yield { .. } => Err(Unpromotable),
265274
kind => {
266275
span_bug!(terminator.source_info.span, "{:?} not promotable", kind);
@@ -587,53 +596,103 @@ impl<'tcx> Validator<'_, 'tcx> {
587596
Ok(())
588597
}
589598

599+
/// Computes the sets of blocks of this MIR that are definitely going to be executed
600+
/// if the function returns successfully. That makes it safe to promote calls in them
601+
/// that might fail.
602+
fn promotion_safe_blocks(body: &mir::Body<'tcx>) -> FxHashSet<BasicBlock> {
603+
let mut safe_blocks = FxHashSet::default();
604+
let mut safe_block = START_BLOCK;
605+
loop {
606+
safe_blocks.insert(safe_block);
607+
// Let's see if we can find another safe block.
608+
safe_block = match body.basic_blocks[safe_block].terminator().kind {
609+
TerminatorKind::Goto { target } => target,
610+
TerminatorKind::Call { target: Some(target), .. }
611+
| TerminatorKind::Drop { target, .. } => {
612+
// This calls a function or the destructor. `target` does not get executed if
613+
// the callee loops or panics. But in both cases the const already fails to
614+
// evaluate, so we are fine considering `target` a safe block for promotion.
615+
target
616+
}
617+
TerminatorKind::Assert { target, .. } => {
618+
// Similar to above, we only consider successful execution.
619+
target
620+
}
621+
_ => {
622+
// No next safe block.
623+
break;
624+
}
625+
};
626+
}
627+
safe_blocks
628+
}
629+
630+
/// Returns whether the block is "safe" for promotion, which means it cannot be dead code.
631+
/// We use this to avoid promoting operations that can fail in dead code.
632+
fn is_promotion_safe_block(&mut self, block: BasicBlock) -> bool {
633+
let body = self.body;
634+
let safe_blocks =
635+
self.promotion_safe_blocks.get_or_insert_with(|| Self::promotion_safe_blocks(body));
636+
safe_blocks.contains(&block)
637+
}
638+
590639
fn validate_call(
591640
&mut self,
592641
callee: &Operand<'tcx>,
593642
args: &[Spanned<Operand<'tcx>>],
643+
block: BasicBlock,
594644
) -> Result<(), Unpromotable> {
645+
// Validate the operands. If they fail, there's no question -- we cannot promote.
646+
self.validate_operand(callee)?;
647+
for arg in args {
648+
self.validate_operand(&arg.node)?;
649+
}
650+
651+
// Functions marked `#[rustc_promotable]` are explicitly allowed to be promoted, so we can
652+
// accept them at this point.
595653
let fn_ty = callee.ty(self.body, self.tcx);
654+
if let ty::FnDef(def_id, _) = *fn_ty.kind() {
655+
if self.tcx.is_promotable_const_fn(def_id) {
656+
return Ok(());
657+
}
658+
}
596659

597-
// Inside const/static items, we promote all (eligible) function calls.
598-
// Everywhere else, we require `#[rustc_promotable]` on the callee.
599-
let promote_all_const_fn = matches!(
660+
// Ideally, we'd stop here and reject the rest.
661+
// But for backward compatibility, we have to accept some promotion in const/static
662+
// initializers. Inline consts are explicitly excluded, they are more recent so we have no
663+
// backwards compatibility reason to allow more promotion inside of them.
664+
let promote_all_fn = matches!(
600665
self.const_kind,
601666
Some(hir::ConstContext::Static(_) | hir::ConstContext::Const { inline: false })
602667
);
603-
if !promote_all_const_fn {
604-
if let ty::FnDef(def_id, _) = *fn_ty.kind() {
605-
// Never promote runtime `const fn` calls of
606-
// functions without `#[rustc_promotable]`.
607-
if !self.tcx.is_promotable_const_fn(def_id) {
608-
return Err(Unpromotable);
609-
}
610-
}
668+
if !promote_all_fn {
669+
return Err(Unpromotable);
611670
}
612-
671+
// Make sure the callee is a `const fn`.
613672
let is_const_fn = match *fn_ty.kind() {
614673
ty::FnDef(def_id, _) => self.tcx.is_const_fn_raw(def_id),
615674
_ => false,
616675
};
617676
if !is_const_fn {
618677
return Err(Unpromotable);
619678
}
620-
621-
self.validate_operand(callee)?;
622-
for arg in args {
623-
self.validate_operand(&arg.node)?;
679+
// The problem is, this may promote calls to functions that panic.
680+
// We don't want to introduce compilation errors if there's a panic in a call in dead code.
681+
// So we ensure that this is not dead code.
682+
if !self.is_promotion_safe_block(block) {
683+
return Err(Unpromotable);
624684
}
625-
685+
// This passed all checks, so let's accept.
626686
Ok(())
627687
}
628688
}
629689

630-
// FIXME(eddyb) remove the differences for promotability in `static`, `const`, `const fn`.
631690
fn validate_candidates(
632691
ccx: &ConstCx<'_, '_>,
633692
temps: &mut IndexSlice<Local, TempState>,
634693
candidates: &[Candidate],
635694
) -> Vec<Candidate> {
636-
let mut validator = Validator { ccx, temps };
695+
let mut validator = Validator { ccx, temps, promotion_safe_blocks: None };
637696

638697
candidates
639698
.iter()
@@ -652,6 +711,10 @@ struct Promoter<'a, 'tcx> {
652711
/// If true, all nested temps are also kept in the
653712
/// source MIR, not moved to the promoted MIR.
654713
keep_original: bool,
714+
715+
/// If true, add the new const (the promoted) to the required_consts of the parent MIR.
716+
/// This is initially false and then set by the visitor when it encounters a `Call` terminator.
717+
add_to_required: bool,
655718
}
656719

657720
impl<'a, 'tcx> Promoter<'a, 'tcx> {
@@ -754,6 +817,10 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
754817
TerminatorKind::Call {
755818
mut func, mut args, call_source: desugar, fn_span, ..
756819
} => {
820+
// This promoted involves a function call, so it may fail to evaluate.
821+
// Let's make sure it is added to `required_consts` so that that failure cannot get lost.
822+
self.add_to_required = true;
823+
757824
self.visit_operand(&mut func, loc);
758825
for arg in &mut args {
759826
self.visit_operand(&mut arg.node, loc);
@@ -788,7 +855,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
788855

789856
fn promote_candidate(mut self, candidate: Candidate, next_promoted_id: usize) -> Body<'tcx> {
790857
let def = self.source.source.def_id();
791-
let mut rvalue = {
858+
let (mut rvalue, promoted_op) = {
792859
let promoted = &mut self.promoted;
793860
let promoted_id = Promoted::new(next_promoted_id);
794861
let tcx = self.tcx;
@@ -798,11 +865,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
798865
let args = tcx.erase_regions(GenericArgs::identity_for_item(tcx, def));
799866
let uneval = mir::UnevaluatedConst { def, args, promoted: Some(promoted_id) };
800867

801-
Operand::Constant(Box::new(ConstOperand {
802-
span,
803-
user_ty: None,
804-
const_: Const::Unevaluated(uneval, ty),
805-
}))
868+
ConstOperand { span, user_ty: None, const_: Const::Unevaluated(uneval, ty) }
806869
};
807870

808871
let blocks = self.source.basic_blocks.as_mut();
@@ -838,22 +901,26 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
838901
let promoted_ref = local_decls.push(promoted_ref);
839902
assert_eq!(self.temps.push(TempState::Unpromotable), promoted_ref);
840903

904+
let promoted_operand = promoted_operand(ref_ty, span);
841905
let promoted_ref_statement = Statement {
842906
source_info: statement.source_info,
843907
kind: StatementKind::Assign(Box::new((
844908
Place::from(promoted_ref),
845-
Rvalue::Use(promoted_operand(ref_ty, span)),
909+
Rvalue::Use(Operand::Constant(Box::new(promoted_operand))),
846910
))),
847911
};
848912
self.extra_statements.push((loc, promoted_ref_statement));
849913

850-
Rvalue::Ref(
851-
tcx.lifetimes.re_erased,
852-
*borrow_kind,
853-
Place {
854-
local: mem::replace(&mut place.local, promoted_ref),
855-
projection: List::empty(),
856-
},
914+
(
915+
Rvalue::Ref(
916+
tcx.lifetimes.re_erased,
917+
*borrow_kind,
918+
Place {
919+
local: mem::replace(&mut place.local, promoted_ref),
920+
projection: List::empty(),
921+
},
922+
),
923+
promoted_operand,
857924
)
858925
};
859926

@@ -865,6 +932,12 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
865932

866933
let span = self.promoted.span;
867934
self.assign(RETURN_PLACE, rvalue, span);
935+
936+
// Now that we did promotion, we know whether we'll want to add this to `required_consts`.
937+
if self.add_to_required {
938+
self.source.required_consts.push(promoted_op);
939+
}
940+
868941
self.promoted
869942
}
870943
}
@@ -924,6 +997,11 @@ fn promote_candidates<'tcx>(
924997
None,
925998
body.tainted_by_errors,
926999
);
1000+
// We keep `required_consts` of the new MIR body empty. All consts mentioned here have
1001+
// already been added to the parent MIR's `required_consts` (that is computed before
1002+
// promotion), and no matter where this promoted const ends up, our parent MIR must be
1003+
// somewhere in the reachable dependency chain so we can rely on its required consts being
1004+
// evaluated.
9271005
promoted.phase = MirPhase::Analysis(AnalysisPhase::Initial);
9281006

9291007
let promoter = Promoter {
@@ -933,6 +1011,7 @@ fn promote_candidates<'tcx>(
9331011
temps: &mut temps,
9341012
extra_statements: &mut extra_statements,
9351013
keep_original: false,
1014+
add_to_required: false,
9361015
};
9371016

9381017
let mut promoted = promoter.promote_candidate(candidate, promotions.len());

0 commit comments

Comments
 (0)