Skip to content

Commit 848a387

Browse files
committed
Auto merge of #116482 - matthewjasper:thir-unsafeck-inline-constants, r=b-naber
Fix inline const pattern unsafety checking in THIR Fix THIR unsafety checking of inline constants. - Steal THIR in THIR unsafety checking (if enabled) instead of MIR lowering. - Represent inline constants in THIR patterns.
2 parents df871fb + 8aea0e9 commit 848a387

24 files changed

+245
-62
lines changed

compiler/rustc_interface/src/passes.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -775,12 +775,16 @@ fn analysis(tcx: TyCtxt<'_>, (): ()) -> Result<()> {
775775
rustc_hir_analysis::check_crate(tcx)?;
776776

777777
sess.time("MIR_borrow_checking", || {
778-
tcx.hir().par_body_owners(|def_id| tcx.ensure().mir_borrowck(def_id));
778+
tcx.hir().par_body_owners(|def_id| {
779+
// Run THIR unsafety check because it's responsible for stealing
780+
// and deallocating THIR when enabled.
781+
tcx.ensure().thir_check_unsafety(def_id);
782+
tcx.ensure().mir_borrowck(def_id)
783+
});
779784
});
780785

781786
sess.time("MIR_effect_checking", || {
782787
for def_id in tcx.hir().body_owners() {
783-
tcx.ensure().thir_check_unsafety(def_id);
784788
if !tcx.sess.opts.unstable_opts.thir_unsafeck {
785789
rustc_mir_transform::check_unsafety::check_unsafety(tcx, def_id);
786790
}

compiler/rustc_middle/src/thir.rs

+21-1
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,8 @@ impl<'tcx> Pat<'tcx> {
636636
Wild | Range(..) | Binding { subpattern: None, .. } | Constant { .. } | Error(_) => {}
637637
AscribeUserType { subpattern, .. }
638638
| Binding { subpattern: Some(subpattern), .. }
639-
| Deref { subpattern } => subpattern.walk_(it),
639+
| Deref { subpattern }
640+
| InlineConstant { subpattern, .. } => subpattern.walk_(it),
640641
Leaf { subpatterns } | Variant { subpatterns, .. } => {
641642
subpatterns.iter().for_each(|field| field.pattern.walk_(it))
642643
}
@@ -764,6 +765,22 @@ pub enum PatKind<'tcx> {
764765
value: mir::Const<'tcx>,
765766
},
766767

768+
/// Inline constant found while lowering a pattern.
769+
InlineConstant {
770+
/// [LocalDefId] of the constant, we need this so that we have a
771+
/// reference that can be used by unsafety checking to visit nested
772+
/// unevaluated constants.
773+
def: LocalDefId,
774+
/// If the inline constant is used in a range pattern, this subpattern
775+
/// represents the range (if both ends are inline constants, there will
776+
/// be multiple InlineConstant wrappers).
777+
///
778+
/// Otherwise, the actual pattern that the constant lowered to. As with
779+
/// other constants, inline constants are matched structurally where
780+
/// possible.
781+
subpattern: Box<Pat<'tcx>>,
782+
},
783+
767784
Range(Box<PatRange<'tcx>>),
768785

769786
/// Matches against a slice, checking the length and extracting elements.
@@ -924,6 +941,9 @@ impl<'tcx> fmt::Display for Pat<'tcx> {
924941
write!(f, "{subpattern}")
925942
}
926943
PatKind::Constant { value } => write!(f, "{value}"),
944+
PatKind::InlineConstant { def: _, ref subpattern } => {
945+
write!(f, "{} (from inline const)", subpattern)
946+
}
927947
PatKind::Range(box PatRange { lo, hi, end }) => {
928948
write!(f, "{lo}")?;
929949
write!(f, "{end}")?;

compiler/rustc_middle/src/thir/visit.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -233,16 +233,17 @@ pub fn walk_pat<'a, 'tcx: 'a, V: Visitor<'a, 'tcx>>(visitor: &mut V, pat: &Pat<'
233233
}
234234
}
235235
Constant { value: _ } => {}
236+
InlineConstant { def: _, subpattern } => visitor.visit_pat(subpattern),
236237
Range(_) => {}
237238
Slice { prefix, slice, suffix } | Array { prefix, slice, suffix } => {
238239
for subpattern in prefix.iter() {
239-
visitor.visit_pat(&subpattern);
240+
visitor.visit_pat(subpattern);
240241
}
241242
if let Some(pat) = slice {
242-
visitor.visit_pat(&pat);
243+
visitor.visit_pat(pat);
243244
}
244245
for subpattern in suffix.iter() {
245-
visitor.visit_pat(&subpattern);
246+
visitor.visit_pat(subpattern);
246247
}
247248
}
248249
Or { pats } => {

compiler/rustc_mir_build/src/build/matches/mod.rs

+4
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
847847
self.visit_primary_bindings(subpattern, subpattern_user_ty, f)
848848
}
849849

850+
PatKind::InlineConstant { ref subpattern, .. } => {
851+
self.visit_primary_bindings(subpattern, pattern_user_ty.clone(), f)
852+
}
853+
850854
PatKind::Leaf { ref subpatterns } => {
851855
for subpattern in subpatterns {
852856
let subpattern_user_ty = pattern_user_ty.clone().leaf(subpattern.field);

compiler/rustc_mir_build/src/build/matches/simplify.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
204204
Err(match_pair)
205205
}
206206

207+
PatKind::InlineConstant { subpattern: ref pattern, def: _ } => {
208+
candidate.match_pairs.push(MatchPair::new(match_pair.place, pattern, self));
209+
210+
Ok(())
211+
}
212+
207213
PatKind::Range(box PatRange { lo, hi, end }) => {
208214
let (range, bias) = match *lo.ty().kind() {
209215
ty::Char => {
@@ -229,8 +235,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
229235
// correct the comparison. This is achieved by XORing with a bias (see
230236
// pattern/_match.rs for another pertinent example of this pattern).
231237
//
232-
// Also, for performance, it's important to only do the second `try_to_bits` if
233-
// necessary.
238+
// Also, for performance, it's important to only do the second
239+
// `try_to_bits` if necessary.
234240
let lo = lo.try_to_bits(sz).unwrap() ^ bias;
235241
if lo <= min {
236242
let hi = hi.try_to_bits(sz).unwrap() ^ bias;

compiler/rustc_mir_build/src/build/matches/test.rs

+2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
7373
PatKind::Or { .. } => bug!("or-patterns should have already been handled"),
7474

7575
PatKind::AscribeUserType { .. }
76+
| PatKind::InlineConstant { .. }
7677
| PatKind::Array { .. }
7778
| PatKind::Wild
7879
| PatKind::Binding { .. }
@@ -111,6 +112,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
111112
| PatKind::Or { .. }
112113
| PatKind::Binding { .. }
113114
| PatKind::AscribeUserType { .. }
115+
| PatKind::InlineConstant { .. }
114116
| PatKind::Leaf { .. }
115117
| PatKind::Deref { .. }
116118
| PatKind::Error(_) => {

compiler/rustc_mir_build/src/build/mod.rs

+15-11
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,7 @@ pub(crate) fn closure_saved_names_of_captured_variables<'tcx>(
5353
}
5454

5555
/// Construct the MIR for a given `DefId`.
56-
fn mir_build(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
57-
// Ensure unsafeck and abstract const building is ran before we steal the THIR.
58-
tcx.ensure_with_value()
59-
.thir_check_unsafety(tcx.typeck_root_def_id(def.to_def_id()).expect_local());
56+
fn mir_build<'tcx>(tcx: TyCtxt<'tcx>, def: LocalDefId) -> Body<'tcx> {
6057
tcx.ensure_with_value().thir_abstract_const(def);
6158
if let Err(e) = tcx.check_match(def) {
6259
return construct_error(tcx, def, e);
@@ -65,20 +62,27 @@ fn mir_build(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
6562
let body = match tcx.thir_body(def) {
6663
Err(error_reported) => construct_error(tcx, def, error_reported),
6764
Ok((thir, expr)) => {
68-
// We ran all queries that depended on THIR at the beginning
69-
// of `mir_build`, so now we can steal it
70-
let thir = thir.steal();
65+
let build_mir = |thir: &Thir<'tcx>| match thir.body_type {
66+
thir::BodyTy::Fn(fn_sig) => construct_fn(tcx, def, thir, expr, fn_sig),
67+
thir::BodyTy::Const(ty) => construct_const(tcx, def, thir, expr, ty),
68+
};
7169

72-
tcx.ensure().check_match(def);
7370
// this must run before MIR dump, because
7471
// "not all control paths return a value" is reported here.
7572
//
7673
// maybe move the check to a MIR pass?
7774
tcx.ensure().check_liveness(def);
7875

79-
match thir.body_type {
80-
thir::BodyTy::Fn(fn_sig) => construct_fn(tcx, def, &thir, expr, fn_sig),
81-
thir::BodyTy::Const(ty) => construct_const(tcx, def, &thir, expr, ty),
76+
if tcx.sess.opts.unstable_opts.thir_unsafeck {
77+
// Don't steal here if THIR unsafeck is being used. Instead
78+
// steal in unsafeck. This is so that pattern inline constants
79+
// can be evaluated as part of building the THIR of the parent
80+
// function without a cycle.
81+
build_mir(&thir.borrow())
82+
} else {
83+
// We ran all queries that depended on THIR at the beginning
84+
// of `mir_build`, so now we can steal it
85+
build_mir(&thir.steal())
8286
}
8387
}
8488
};

compiler/rustc_mir_build/src/check_unsafety.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
124124
/// Handle closures/coroutines/inline-consts, which is unsafecked with their parent body.
125125
fn visit_inner_body(&mut self, def: LocalDefId) {
126126
if let Ok((inner_thir, expr)) = self.tcx.thir_body(def) {
127-
let inner_thir = &inner_thir.borrow();
127+
// Runs all other queries that depend on THIR.
128+
self.tcx.ensure_with_value().mir_built(def);
129+
let inner_thir = &inner_thir.steal();
128130
let hir_context = self.tcx.hir().local_def_id_to_hir_id(def);
129131
let mut inner_visitor = UnsafetyVisitor { thir: inner_thir, hir_context, ..*self };
130132
inner_visitor.visit_expr(&inner_thir[expr]);
@@ -224,6 +226,7 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
224226
PatKind::Wild |
225227
// these just wrap other patterns
226228
PatKind::Or { .. } |
229+
PatKind::InlineConstant { .. } |
227230
PatKind::AscribeUserType { .. } |
228231
PatKind::Error(_) => {}
229232
}
@@ -277,6 +280,9 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
277280
visit::walk_pat(self, pat);
278281
self.inside_adt = old_inside_adt;
279282
}
283+
PatKind::InlineConstant { def, .. } => {
284+
self.visit_inner_body(*def);
285+
}
280286
_ => {
281287
visit::walk_pat(self, pat);
282288
}
@@ -788,7 +794,9 @@ pub fn thir_check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) {
788794
}
789795

790796
let Ok((thir, expr)) = tcx.thir_body(def) else { return };
791-
let thir = &thir.borrow();
797+
// Runs all other queries that depend on THIR.
798+
tcx.ensure_with_value().mir_built(def);
799+
let thir = &thir.steal();
792800
// If `thir` is empty, a type error occurred, skip this body.
793801
if thir.exprs.is_empty() {
794802
return;

compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1326,7 +1326,8 @@ impl<'p, 'tcx> DeconstructedPat<'p, 'tcx> {
13261326
let ctor;
13271327
let fields;
13281328
match &pat.kind {
1329-
PatKind::AscribeUserType { subpattern, .. } => return mkpat(subpattern),
1329+
PatKind::AscribeUserType { subpattern, .. }
1330+
| PatKind::InlineConstant { subpattern, .. } => return mkpat(subpattern),
13301331
PatKind::Binding { subpattern: Some(subpat), .. } => return mkpat(subpat),
13311332
PatKind::Binding { subpattern: None, .. } | PatKind::Wild => {
13321333
ctor = Wildcard;

compiler/rustc_mir_build/src/thir/pattern/mod.rs

+30-16
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use rustc_middle::ty::{
2727
self, AdtDef, CanonicalUserTypeAnnotation, GenericArg, GenericArgsRef, Region, Ty, TyCtxt,
2828
TypeVisitableExt, UserType,
2929
};
30+
use rustc_span::def_id::LocalDefId;
3031
use rustc_span::{ErrorGuaranteed, Span, Symbol};
3132
use rustc_target::abi::{FieldIdx, Integer};
3233

@@ -88,15 +89,21 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
8889
fn lower_pattern_range_endpoint(
8990
&mut self,
9091
expr: Option<&'tcx hir::Expr<'tcx>>,
91-
) -> Result<(Option<mir::Const<'tcx>>, Option<Ascription<'tcx>>), ErrorGuaranteed> {
92+
) -> Result<
93+
(Option<mir::Const<'tcx>>, Option<Ascription<'tcx>>, Option<LocalDefId>),
94+
ErrorGuaranteed,
95+
> {
9296
match expr {
93-
None => Ok((None, None)),
97+
None => Ok((None, None, None)),
9498
Some(expr) => {
95-
let (kind, ascr) = match self.lower_lit(expr) {
99+
let (kind, ascr, inline_const) = match self.lower_lit(expr) {
100+
PatKind::InlineConstant { subpattern, def } => {
101+
(subpattern.kind, None, Some(def))
102+
}
96103
PatKind::AscribeUserType { ascription, subpattern: box Pat { kind, .. } } => {
97-
(kind, Some(ascription))
104+
(kind, Some(ascription), None)
98105
}
99-
kind => (kind, None),
106+
kind => (kind, None, None),
100107
};
101108
let value = if let PatKind::Constant { value } = kind {
102109
value
@@ -106,7 +113,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
106113
);
107114
return Err(self.tcx.sess.delay_span_bug(expr.span, msg));
108115
};
109-
Ok((Some(value), ascr))
116+
Ok((Some(value), ascr, inline_const))
110117
}
111118
}
112119
}
@@ -177,8 +184,8 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
177184
return Err(self.tcx.sess.delay_span_bug(span, msg));
178185
}
179186

180-
let (lo, lo_ascr) = self.lower_pattern_range_endpoint(lo_expr)?;
181-
let (hi, hi_ascr) = self.lower_pattern_range_endpoint(hi_expr)?;
187+
let (lo, lo_ascr, lo_inline) = self.lower_pattern_range_endpoint(lo_expr)?;
188+
let (hi, hi_ascr, hi_inline) = self.lower_pattern_range_endpoint(hi_expr)?;
182189

183190
let lo = lo.unwrap_or_else(|| {
184191
// Unwrap is ok because the type is known to be numeric.
@@ -237,6 +244,12 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
237244
};
238245
}
239246
}
247+
for inline_const in [lo_inline, hi_inline] {
248+
if let Some(def) = inline_const {
249+
kind =
250+
PatKind::InlineConstant { def, subpattern: Box::new(Pat { span, ty, kind }) };
251+
}
252+
}
240253
Ok(kind)
241254
}
242255

@@ -599,11 +612,9 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
599612
// const eval path below.
600613
// FIXME: investigate the performance impact of removing this.
601614
let lit_input = match expr.kind {
602-
hir::ExprKind::Lit(ref lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: false }),
603-
hir::ExprKind::Unary(hir::UnOp::Neg, ref expr) => match expr.kind {
604-
hir::ExprKind::Lit(ref lit) => {
605-
Some(LitToConstInput { lit: &lit.node, ty, neg: true })
606-
}
615+
hir::ExprKind::Lit(lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: false }),
616+
hir::ExprKind::Unary(hir::UnOp::Neg, expr) => match expr.kind {
617+
hir::ExprKind::Lit(lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: true }),
607618
_ => None,
608619
},
609620
_ => None,
@@ -633,13 +644,13 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
633644
if let Ok(Some(valtree)) =
634645
self.tcx.const_eval_resolve_for_typeck(self.param_env, ct, Some(span))
635646
{
636-
self.const_to_pat(
647+
let subpattern = self.const_to_pat(
637648
Const::Ty(ty::Const::new_value(self.tcx, valtree, ty)),
638649
id,
639650
span,
640651
None,
641-
)
642-
.kind
652+
);
653+
PatKind::InlineConstant { subpattern, def: def_id }
643654
} else {
644655
// If that fails, convert it to an opaque constant pattern.
645656
match tcx.const_eval_resolve(self.param_env, uneval, Some(span)) {
@@ -822,6 +833,9 @@ impl<'tcx> PatternFoldable<'tcx> for PatKind<'tcx> {
822833
PatKind::Deref { subpattern: subpattern.fold_with(folder) }
823834
}
824835
PatKind::Constant { value } => PatKind::Constant { value },
836+
PatKind::InlineConstant { def, subpattern: ref pattern } => {
837+
PatKind::InlineConstant { def, subpattern: pattern.fold_with(folder) }
838+
}
825839
PatKind::Range(ref range) => PatKind::Range(range.clone()),
826840
PatKind::Slice { ref prefix, ref slice, ref suffix } => PatKind::Slice {
827841
prefix: prefix.fold_with(folder),

compiler/rustc_mir_build/src/thir/print.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
692692
}
693693
PatKind::Deref { subpattern } => {
694694
print_indented!(self, "Deref { ", depth_lvl + 1);
695-
print_indented!(self, "subpattern: ", depth_lvl + 2);
695+
print_indented!(self, "subpattern:", depth_lvl + 2);
696696
self.print_pat(subpattern, depth_lvl + 2);
697697
print_indented!(self, "}", depth_lvl + 1);
698698
}
@@ -701,6 +701,13 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
701701
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
702702
print_indented!(self, "}", depth_lvl + 1);
703703
}
704+
PatKind::InlineConstant { def, subpattern } => {
705+
print_indented!(self, "InlineConstant {", depth_lvl + 1);
706+
print_indented!(self, format!("def: {:?}", def), depth_lvl + 2);
707+
print_indented!(self, "subpattern:", depth_lvl + 2);
708+
self.print_pat(subpattern, depth_lvl + 2);
709+
print_indented!(self, "}", depth_lvl + 1);
710+
}
704711
PatKind::Range(pat_range) => {
705712
print_indented!(self, format!("Range ( {:?} )", pat_range), depth_lvl + 1);
706713
}

tests/ui/async-await/async-unsafe-fn-call-in-safe.mir.stderr

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ LL | S::f();
2323
= note: consult the function's documentation for information on how to avoid undefined behavior
2424

2525
error[E0133]: call to unsafe function is unsafe and requires unsafe function or block
26-
--> $DIR/async-unsafe-fn-call-in-safe.rs:24:5
26+
--> $DIR/async-unsafe-fn-call-in-safe.rs:26:5
2727
|
2828
LL | f();
2929
| ^^^ call to unsafe function

tests/ui/async-await/async-unsafe-fn-call-in-safe.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ async fn g() {
2020
}
2121

2222
fn main() {
23-
S::f(); //[mir]~ ERROR call to unsafe function is unsafe
24-
f(); //[mir]~ ERROR call to unsafe function is unsafe
23+
S::f();
24+
//[mir]~^ ERROR call to unsafe function is unsafe
25+
//[thir]~^^ ERROR call to unsafe function `S::f` is unsafe
26+
f();
27+
//[mir]~^ ERROR call to unsafe function is unsafe
28+
//[thir]~^^ ERROR call to unsafe function `f` is unsafe
2529
}

0 commit comments

Comments
 (0)