Skip to content

Commit 622667f

Browse files
committed
Auto merge of rust-lang#121557 - RalfJung:const-fn-call-promotion, r=oli-obk
restrict promotion of `const fn` calls We only promote them in `const`/`static` initializers, but even that is still unfortunate -- we still cannot add promoteds to required_consts. But we should add them there to make sure it's always okay to evaluate every const we encounter in a MIR body. That effort of not promoting things that can fail to evaluate is tracked in rust-lang#80619. These `const fn` calls are the last missing piece. So I propose that we do not promote const-fn calls in const when that may fail without the entire const failing, thereby completing rust-lang#80619. Unfortunately we can't just reject promoting these functions outright due to backwards compatibility. So let's see if we can find a hack that makes crater happy... For the record, this is the [crater analysis](rust-lang#80243 (comment)) from when I tried to entirely forbid this kind of promotion. It's a tiny amount of breakage and if we had a nice alternative for code like that, we could conceivably push it through... but sadly, inline const expressions are still blocked on t-lang concerns about post-monomorphization errors and we haven't yet figured out an implementation that can resolve those concerns. So we're forced to make progress via other means, such as terrible hacks like this. Attempt one: only promote calls on the "safe path" at the beginning of a MIR block. This is the path that starts at the start block and continues via gotos and calls, but stops at the first branch. If we had imposed this restriction before stabilizing `if` and `match` in `const`, this would have definitely been sufficient... EDIT: Turns out that works. :) **Here's the t-lang [nomination comment](rust-lang#121557 (comment) And here's the [FCP comment](rust-lang#121557 (comment)). r? `@oli-obk`
2 parents 40dcd79 + 85433c4 commit 622667f

20 files changed

+343
-295
lines changed

compiler/rustc_const_eval/src/const_eval/dummy_machine.rs

+3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ impl<'mir, 'tcx: 'mir> interpret::Machine<'mir, 'tcx> for DummyMachine {
4646
type MemoryKind = !;
4747
const PANIC_ON_ALLOC_FAIL: bool = true;
4848

49+
// We want to just eval random consts in the program, so `eval_mir_const` can fail.
50+
const ALL_CONSTS_ARE_PRECHECKED: bool = false;
51+
4952
#[inline(always)]
5053
fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
5154
false // no reason to enforce alignment

compiler/rustc_const_eval/src/interpret/eval_context.rs

+11-11
Original file line numberDiff line numberDiff line change
@@ -822,15 +822,13 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
822822
self.stack_mut().push(frame);
823823

824824
// Make sure all the constants required by this frame evaluate successfully (post-monomorphization check).
825-
if M::POST_MONO_CHECKS {
826-
for &const_ in &body.required_consts {
827-
let c = self
828-
.instantiate_from_current_frame_and_normalize_erasing_regions(const_.const_)?;
829-
c.eval(*self.tcx, self.param_env, const_.span).map_err(|err| {
830-
err.emit_note(*self.tcx);
831-
err
832-
})?;
833-
}
825+
for &const_ in &body.required_consts {
826+
let c =
827+
self.instantiate_from_current_frame_and_normalize_erasing_regions(const_.const_)?;
828+
c.eval(*self.tcx, self.param_env, const_.span).map_err(|err| {
829+
err.emit_note(*self.tcx);
830+
err
831+
})?;
834832
}
835833

836834
// done
@@ -1181,8 +1179,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
11811179
) -> InterpResult<'tcx, OpTy<'tcx, M::Provenance>> {
11821180
M::eval_mir_constant(self, *val, span, layout, |ecx, val, span, layout| {
11831181
let const_val = val.eval(*ecx.tcx, ecx.param_env, span).map_err(|err| {
1184-
// FIXME: somehow this is reachable even when POST_MONO_CHECKS is on.
1185-
// Are we not always populating `required_consts`?
1182+
if M::ALL_CONSTS_ARE_PRECHECKED && !matches!(err, ErrorHandled::TooGeneric(..)) {
1183+
// Looks like the const is not captued by `required_consts`, that's bad.
1184+
bug!("interpret const eval failure of {val:?} which is not in required_consts");
1185+
}
11861186
err.emit_note(*ecx.tcx);
11871187
err
11881188
})?;

compiler/rustc_const_eval/src/interpret/machine.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized {
140140
/// Should the machine panic on allocation failures?
141141
const PANIC_ON_ALLOC_FAIL: bool;
142142

143-
/// Should post-monomorphization checks be run when a stack frame is pushed?
144-
const POST_MONO_CHECKS: bool = true;
143+
/// Determines whether `eval_mir_constant` can never fail because all required consts have
144+
/// already been checked before.
145+
const ALL_CONSTS_ARE_PRECHECKED: bool = true;
145146

146147
/// Whether memory accesses should be alignment-checked.
147148
fn enforce_alignment(ecx: &InterpCx<'mir, 'tcx, Self>) -> bool;

compiler/rustc_middle/src/mir/consts.rs

+14
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,20 @@ impl<'tcx> Const<'tcx> {
238238
}
239239
}
240240

241+
/// Determines whether we need to add this const to `required_consts`. This is the case if and
242+
/// only if evaluating it may error.
243+
#[inline]
244+
pub fn is_required_const(&self) -> bool {
245+
match self {
246+
Const::Ty(c) => match c.kind() {
247+
ty::ConstKind::Value(_) => false, // already a value, cannot error
248+
_ => true,
249+
},
250+
Const::Val(..) => false, // already a value, cannot error
251+
Const::Unevaluated(..) => true,
252+
}
253+
}
254+
241255
#[inline]
242256
pub fn try_to_scalar(self) -> Option<Scalar> {
243257
match self {

compiler/rustc_mir_transform/src/inline.rs

+9-14
Original file line numberDiff line numberDiff line change
@@ -720,18 +720,12 @@ impl<'tcx> Inliner<'tcx> {
720720
kind: TerminatorKind::Goto { target: integrator.map_block(START_BLOCK) },
721721
});
722722

723-
// Copy only unevaluated constants from the callee_body into the caller_body.
724-
// Although we are only pushing `ConstKind::Unevaluated` consts to
725-
// `required_consts`, here we may not only have `ConstKind::Unevaluated`
726-
// because we are calling `instantiate_and_normalize_erasing_regions`.
727-
caller_body.required_consts.extend(callee_body.required_consts.iter().copied().filter(
728-
|&ct| match ct.const_ {
729-
Const::Ty(_) => {
730-
bug!("should never encounter ty::UnevaluatedConst in `required_consts`")
731-
}
732-
Const::Val(..) | Const::Unevaluated(..) => true,
733-
},
734-
));
723+
// Copy required constants from the callee_body into the caller_body. Although we are only
724+
// pushing unevaluated consts to `required_consts`, here they may have been evaluated
725+
// because we are calling `instantiate_and_normalize_erasing_regions` -- so we filter again.
726+
caller_body.required_consts.extend(
727+
callee_body.required_consts.into_iter().filter(|ct| ct.const_.is_required_const()),
728+
);
735729
// Now that we incorporated the callee's `required_consts`, we can remove the callee from
736730
// `mentioned_items` -- but we have to take their `mentioned_items` in return. This does
737731
// some extra work here to save the monomorphization collector work later. It helps a lot,
@@ -747,8 +741,9 @@ impl<'tcx> Inliner<'tcx> {
747741
caller_body.mentioned_items.remove(idx);
748742
caller_body.mentioned_items.extend(callee_body.mentioned_items);
749743
} else {
750-
// If we can't find the callee, there's no point in adding its items.
751-
// Probably it already got removed by being inlined elsewhere in the same function.
744+
// If we can't find the callee, there's no point in adding its items. Probably it
745+
// already got removed by being inlined elsewhere in the same function, so we already
746+
// took its items.
752747
}
753748
}
754749

compiler/rustc_mir_transform/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,8 @@ fn mir_promoted(
333333
body.tainted_by_errors = Some(error_reported);
334334
}
335335

336+
// Collect `required_consts` *before* promotion, so if there are any consts being promoted
337+
// we still add them to the list in the outer MIR body.
336338
let mut required_consts = Vec::new();
337339
let mut required_consts_visitor = RequiredConstsVisitor::new(&mut required_consts);
338340
for (bb, bb_data) in traversal::reverse_postorder(&body) {

compiler/rustc_mir_transform/src/promote_consts.rs

+117-34
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//! move analysis runs after promotion on broken MIR.
1414
1515
use either::{Left, Right};
16+
use rustc_data_structures::fx::FxHashSet;
1617
use rustc_hir as hir;
1718
use rustc_middle::mir;
1819
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
@@ -175,6 +176,12 @@ fn collect_temps_and_candidates<'tcx>(
175176
struct Validator<'a, 'tcx> {
176177
ccx: &'a ConstCx<'a, 'tcx>,
177178
temps: &'a mut IndexSlice<Local, TempState>,
179+
/// For backwards compatibility, we are promoting function calls in `const`/`static`
180+
/// initializers. But we want to avoid evaluating code that might panic and that otherwise would
181+
/// not have been evaluated, so we only promote such calls in basic blocks that are guaranteed
182+
/// to execute. In other words, we only promote such calls in basic blocks that are definitely
183+
/// not dead code. Here we cache the result of computing that set of basic blocks.
184+
promotion_safe_blocks: Option<FxHashSet<BasicBlock>>,
178185
}
179186

180187
impl<'a, 'tcx> std::ops::Deref for Validator<'a, 'tcx> {
@@ -260,7 +267,9 @@ impl<'tcx> Validator<'_, 'tcx> {
260267
self.validate_rvalue(rhs)
261268
}
262269
Right(terminator) => match &terminator.kind {
263-
TerminatorKind::Call { func, args, .. } => self.validate_call(func, args),
270+
TerminatorKind::Call { func, args, .. } => {
271+
self.validate_call(func, args, loc.block)
272+
}
264273
TerminatorKind::Yield { .. } => Err(Unpromotable),
265274
kind => {
266275
span_bug!(terminator.source_info.span, "{:?} not promotable", kind);
@@ -588,53 +597,103 @@ impl<'tcx> Validator<'_, 'tcx> {
588597
Ok(())
589598
}
590599

600+
/// Computes the sets of blocks of this MIR that are definitely going to be executed
601+
/// if the function returns successfully. That makes it safe to promote calls in them
602+
/// that might fail.
603+
fn promotion_safe_blocks(body: &mir::Body<'tcx>) -> FxHashSet<BasicBlock> {
604+
let mut safe_blocks = FxHashSet::default();
605+
let mut safe_block = START_BLOCK;
606+
loop {
607+
safe_blocks.insert(safe_block);
608+
// Let's see if we can find another safe block.
609+
safe_block = match body.basic_blocks[safe_block].terminator().kind {
610+
TerminatorKind::Goto { target } => target,
611+
TerminatorKind::Call { target: Some(target), .. }
612+
| TerminatorKind::Drop { target, .. } => {
613+
// This calls a function or the destructor. `target` does not get executed if
614+
// the callee loops or panics. But in both cases the const already fails to
615+
// evaluate, so we are fine considering `target` a safe block for promotion.
616+
target
617+
}
618+
TerminatorKind::Assert { target, .. } => {
619+
// Similar to above, we only consider successful execution.
620+
target
621+
}
622+
_ => {
623+
// No next safe block.
624+
break;
625+
}
626+
};
627+
}
628+
safe_blocks
629+
}
630+
631+
/// Returns whether the block is "safe" for promotion, which means it cannot be dead code.
632+
/// We use this to avoid promoting operations that can fail in dead code.
633+
fn is_promotion_safe_block(&mut self, block: BasicBlock) -> bool {
634+
let body = self.body;
635+
let safe_blocks =
636+
self.promotion_safe_blocks.get_or_insert_with(|| Self::promotion_safe_blocks(body));
637+
safe_blocks.contains(&block)
638+
}
639+
591640
fn validate_call(
592641
&mut self,
593642
callee: &Operand<'tcx>,
594643
args: &[Spanned<Operand<'tcx>>],
644+
block: BasicBlock,
595645
) -> Result<(), Unpromotable> {
646+
// Validate the operands. If they fail, there's no question -- we cannot promote.
647+
self.validate_operand(callee)?;
648+
for arg in args {
649+
self.validate_operand(&arg.node)?;
650+
}
651+
652+
// Functions marked `#[rustc_promotable]` are explicitly allowed to be promoted, so we can
653+
// accept them at this point.
596654
let fn_ty = callee.ty(self.body, self.tcx);
655+
if let ty::FnDef(def_id, _) = *fn_ty.kind() {
656+
if self.tcx.is_promotable_const_fn(def_id) {
657+
return Ok(());
658+
}
659+
}
597660

598-
// Inside const/static items, we promote all (eligible) function calls.
599-
// Everywhere else, we require `#[rustc_promotable]` on the callee.
600-
let promote_all_const_fn = matches!(
661+
// Ideally, we'd stop here and reject the rest.
662+
// But for backward compatibility, we have to accept some promotion in const/static
663+
// initializers. Inline consts are explicitly excluded, they are more recent so we have no
664+
// backwards compatibility reason to allow more promotion inside of them.
665+
let promote_all_fn = matches!(
601666
self.const_kind,
602667
Some(hir::ConstContext::Static(_) | hir::ConstContext::Const { inline: false })
603668
);
604-
if !promote_all_const_fn {
605-
if let ty::FnDef(def_id, _) = *fn_ty.kind() {
606-
// Never promote runtime `const fn` calls of
607-
// functions without `#[rustc_promotable]`.
608-
if !self.tcx.is_promotable_const_fn(def_id) {
609-
return Err(Unpromotable);
610-
}
611-
}
669+
if !promote_all_fn {
670+
return Err(Unpromotable);
612671
}
613-
672+
// Make sure the callee is a `const fn`.
614673
let is_const_fn = match *fn_ty.kind() {
615674
ty::FnDef(def_id, _) => self.tcx.is_const_fn_raw(def_id),
616675
_ => false,
617676
};
618677
if !is_const_fn {
619678
return Err(Unpromotable);
620679
}
621-
622-
self.validate_operand(callee)?;
623-
for arg in args {
624-
self.validate_operand(&arg.node)?;
680+
// The problem is, this may promote calls to functions that panic.
681+
// We don't want to introduce compilation errors if there's a panic in a call in dead code.
682+
// So we ensure that this is not dead code.
683+
if !self.is_promotion_safe_block(block) {
684+
return Err(Unpromotable);
625685
}
626-
686+
// This passed all checks, so let's accept.
627687
Ok(())
628688
}
629689
}
630690

631-
// FIXME(eddyb) remove the differences for promotability in `static`, `const`, `const fn`.
632691
fn validate_candidates(
633692
ccx: &ConstCx<'_, '_>,
634693
temps: &mut IndexSlice<Local, TempState>,
635694
candidates: &[Candidate],
636695
) -> Vec<Candidate> {
637-
let mut validator = Validator { ccx, temps };
696+
let mut validator = Validator { ccx, temps, promotion_safe_blocks: None };
638697

639698
candidates
640699
.iter()
@@ -653,6 +712,10 @@ struct Promoter<'a, 'tcx> {
653712
/// If true, all nested temps are also kept in the
654713
/// source MIR, not moved to the promoted MIR.
655714
keep_original: bool,
715+
716+
/// If true, add the new const (the promoted) to the required_consts of the parent MIR.
717+
/// This is initially false and then set by the visitor when it encounters a `Call` terminator.
718+
add_to_required: bool,
656719
}
657720

658721
impl<'a, 'tcx> Promoter<'a, 'tcx> {
@@ -755,6 +818,10 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
755818
TerminatorKind::Call {
756819
mut func, mut args, call_source: desugar, fn_span, ..
757820
} => {
821+
// This promoted involves a function call, so it may fail to evaluate.
822+
// Let's make sure it is added to `required_consts` so that that failure cannot get lost.
823+
self.add_to_required = true;
824+
758825
self.visit_operand(&mut func, loc);
759826
for arg in &mut args {
760827
self.visit_operand(&mut arg.node, loc);
@@ -789,7 +856,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
789856

790857
fn promote_candidate(mut self, candidate: Candidate, next_promoted_id: usize) -> Body<'tcx> {
791858
let def = self.source.source.def_id();
792-
let mut rvalue = {
859+
let (mut rvalue, promoted_op) = {
793860
let promoted = &mut self.promoted;
794861
let promoted_id = Promoted::new(next_promoted_id);
795862
let tcx = self.tcx;
@@ -799,11 +866,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
799866
let args = tcx.erase_regions(GenericArgs::identity_for_item(tcx, def));
800867
let uneval = mir::UnevaluatedConst { def, args, promoted: Some(promoted_id) };
801868

802-
Operand::Constant(Box::new(ConstOperand {
803-
span,
804-
user_ty: None,
805-
const_: Const::Unevaluated(uneval, ty),
806-
}))
869+
ConstOperand { span, user_ty: None, const_: Const::Unevaluated(uneval, ty) }
807870
};
808871

809872
let blocks = self.source.basic_blocks.as_mut();
@@ -836,22 +899,26 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
836899
let promoted_ref = local_decls.push(promoted_ref);
837900
assert_eq!(self.temps.push(TempState::Unpromotable), promoted_ref);
838901

902+
let promoted_operand = promoted_operand(ref_ty, span);
839903
let promoted_ref_statement = Statement {
840904
source_info: statement.source_info,
841905
kind: StatementKind::Assign(Box::new((
842906
Place::from(promoted_ref),
843-
Rvalue::Use(promoted_operand(ref_ty, span)),
907+
Rvalue::Use(Operand::Constant(Box::new(promoted_operand))),
844908
))),
845909
};
846910
self.extra_statements.push((loc, promoted_ref_statement));
847911

848-
Rvalue::Ref(
849-
tcx.lifetimes.re_erased,
850-
*borrow_kind,
851-
Place {
852-
local: mem::replace(&mut place.local, promoted_ref),
853-
projection: List::empty(),
854-
},
912+
(
913+
Rvalue::Ref(
914+
tcx.lifetimes.re_erased,
915+
*borrow_kind,
916+
Place {
917+
local: mem::replace(&mut place.local, promoted_ref),
918+
projection: List::empty(),
919+
},
920+
),
921+
promoted_operand,
855922
)
856923
};
857924

@@ -863,6 +930,12 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
863930

864931
let span = self.promoted.span;
865932
self.assign(RETURN_PLACE, rvalue, span);
933+
934+
// Now that we did promotion, we know whether we'll want to add this to `required_consts`.
935+
if self.add_to_required {
936+
self.source.required_consts.push(promoted_op);
937+
}
938+
866939
self.promoted
867940
}
868941
}
@@ -878,6 +951,14 @@ impl<'a, 'tcx> MutVisitor<'tcx> for Promoter<'a, 'tcx> {
878951
*local = self.promote_temp(*local);
879952
}
880953
}
954+
955+
fn visit_constant(&mut self, constant: &mut ConstOperand<'tcx>, _location: Location) {
956+
if constant.const_.is_required_const() {
957+
self.promoted.required_consts.push(*constant);
958+
}
959+
960+
// Skipping `super_constant` as the visitor is otherwise only looking for locals.
961+
}
881962
}
882963

883964
fn promote_candidates<'tcx>(
@@ -931,8 +1012,10 @@ fn promote_candidates<'tcx>(
9311012
temps: &mut temps,
9321013
extra_statements: &mut extra_statements,
9331014
keep_original: false,
1015+
add_to_required: false,
9341016
};
9351017

1018+
// `required_consts` of the promoted itself gets filled while building the MIR body.
9361019
let mut promoted = promoter.promote_candidate(candidate, promotions.len());
9371020
promoted.source.promoted = Some(promotions.next_index());
9381021
promotions.push(promoted);

0 commit comments

Comments
 (0)