Skip to content

Commit 5cf7bbe

Browse files
authored
Rollup merge of rust-lang#119563 - compiler-errors:coroutine-resume, r=oli-obk
Check yield terminator's resume type in borrowck In borrowck, we didn't check that the lifetimes of the `TerminatorKind::Yield`'s `resume_place` were actually compatible with the coroutine's signature. That means that the lifetimes were totally going unchecked. Whoops! This PR implements this checking. Fixes rust-lang#119564 r? types
2 parents cf4cd93 + 1d48f69 commit 5cf7bbe

File tree

12 files changed

+190
-36
lines changed

12 files changed

+190
-36
lines changed

compiler/rustc_borrowck/src/type_check/input_output.rs

+12-21
Original file line numberDiff line numberDiff line change
@@ -94,31 +94,22 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
9494
);
9595
}
9696

97-
debug!(
98-
"equate_inputs_and_outputs: body.yield_ty {:?}, universal_regions.yield_ty {:?}",
99-
body.yield_ty(),
100-
universal_regions.yield_ty
101-
);
102-
103-
// We will not have a universal_regions.yield_ty if we yield (by accident)
104-
// outside of a coroutine and return an `impl Trait`, so emit a span_delayed_bug
105-
// because we don't want to panic in an assert here if we've already got errors.
106-
if body.yield_ty().is_some() != universal_regions.yield_ty.is_some() {
107-
self.tcx().dcx().span_delayed_bug(
108-
body.span,
109-
format!(
110-
"Expected body to have yield_ty ({:?}) iff we have a UR yield_ty ({:?})",
111-
body.yield_ty(),
112-
universal_regions.yield_ty,
113-
),
97+
if let Some(mir_yield_ty) = body.yield_ty() {
98+
let yield_span = body.local_decls[RETURN_PLACE].source_info.span;
99+
self.equate_normalized_input_or_output(
100+
universal_regions.yield_ty.unwrap(),
101+
mir_yield_ty,
102+
yield_span,
114103
);
115104
}
116105

117-
if let (Some(mir_yield_ty), Some(ur_yield_ty)) =
118-
(body.yield_ty(), universal_regions.yield_ty)
119-
{
106+
if let Some(mir_resume_ty) = body.resume_ty() {
120107
let yield_span = body.local_decls[RETURN_PLACE].source_info.span;
121-
self.equate_normalized_input_or_output(ur_yield_ty, mir_yield_ty, yield_span);
108+
self.equate_normalized_input_or_output(
109+
universal_regions.resume_ty.unwrap(),
110+
mir_resume_ty,
111+
yield_span,
112+
);
122113
}
123114

124115
// Return types are a bit more complex. They may contain opaque `impl Trait` types.

compiler/rustc_borrowck/src/type_check/liveness/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ impl<'cx, 'tcx> Visitor<'tcx> for LiveVariablesVisitor<'cx, 'tcx> {
183183
match ty_context {
184184
TyContext::ReturnTy(SourceInfo { span, .. })
185185
| TyContext::YieldTy(SourceInfo { span, .. })
186+
| TyContext::ResumeTy(SourceInfo { span, .. })
186187
| TyContext::UserTy(span)
187188
| TyContext::LocalDecl { source_info: SourceInfo { span, .. }, .. } => {
188189
span_bug!(span, "should not be visiting outside of the CFG: {:?}", ty_context);

compiler/rustc_borrowck/src/type_check/mod.rs

+24-2
Original file line numberDiff line numberDiff line change
@@ -1450,13 +1450,13 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
14501450
}
14511451
}
14521452
}
1453-
TerminatorKind::Yield { value, .. } => {
1453+
TerminatorKind::Yield { value, resume_arg, .. } => {
14541454
self.check_operand(value, term_location);
14551455

1456-
let value_ty = value.ty(body, tcx);
14571456
match body.yield_ty() {
14581457
None => span_mirbug!(self, term, "yield in non-coroutine"),
14591458
Some(ty) => {
1459+
let value_ty = value.ty(body, tcx);
14601460
if let Err(terr) = self.sub_types(
14611461
value_ty,
14621462
ty,
@@ -1474,6 +1474,28 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
14741474
}
14751475
}
14761476
}
1477+
1478+
match body.resume_ty() {
1479+
None => span_mirbug!(self, term, "yield in non-coroutine"),
1480+
Some(ty) => {
1481+
let resume_ty = resume_arg.ty(body, tcx);
1482+
if let Err(terr) = self.sub_types(
1483+
ty,
1484+
resume_ty.ty,
1485+
term_location.to_locations(),
1486+
ConstraintCategory::Yield,
1487+
) {
1488+
span_mirbug!(
1489+
self,
1490+
term,
1491+
"type of resume place is {:?}, but the resume type is {:?}: {:?}",
1492+
resume_ty,
1493+
ty,
1494+
terr
1495+
);
1496+
}
1497+
}
1498+
}
14771499
}
14781500
}
14791501
}

compiler/rustc_borrowck/src/universal_regions.rs

+9-3
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ pub struct UniversalRegions<'tcx> {
7676
pub unnormalized_input_tys: &'tcx [Ty<'tcx>],
7777

7878
pub yield_ty: Option<Ty<'tcx>>,
79+
80+
pub resume_ty: Option<Ty<'tcx>>,
7981
}
8082

8183
/// The "defining type" for this MIR. The key feature of the "defining
@@ -525,9 +527,12 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
525527
debug!("build: extern regions = {}..{}", first_extern_index, first_local_index);
526528
debug!("build: local regions = {}..{}", first_local_index, num_universals);
527529

528-
let yield_ty = match defining_ty {
529-
DefiningTy::Coroutine(_, args) => Some(args.as_coroutine().yield_ty()),
530-
_ => None,
530+
let (resume_ty, yield_ty) = match defining_ty {
531+
DefiningTy::Coroutine(_, args) => {
532+
let tys = args.as_coroutine();
533+
(Some(tys.resume_ty()), Some(tys.yield_ty()))
534+
}
535+
_ => (None, None),
531536
};
532537

533538
UniversalRegions {
@@ -541,6 +546,7 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
541546
unnormalized_output_ty: *unnormalized_output_ty,
542547
unnormalized_input_tys,
543548
yield_ty,
549+
resume_ty,
544550
}
545551
}
546552

compiler/rustc_middle/src/mir/mod.rs

+9
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ pub struct CoroutineInfo<'tcx> {
250250
/// The yield type of the function, if it is a coroutine.
251251
pub yield_ty: Option<Ty<'tcx>>,
252252

253+
/// The resume type of the function, if it is a coroutine.
254+
pub resume_ty: Option<Ty<'tcx>>,
255+
253256
/// Coroutine drop glue.
254257
pub coroutine_drop: Option<Body<'tcx>>,
255258

@@ -385,6 +388,7 @@ impl<'tcx> Body<'tcx> {
385388
coroutine: coroutine_kind.map(|coroutine_kind| {
386389
Box::new(CoroutineInfo {
387390
yield_ty: None,
391+
resume_ty: None,
388392
coroutine_drop: None,
389393
coroutine_layout: None,
390394
coroutine_kind,
@@ -551,6 +555,11 @@ impl<'tcx> Body<'tcx> {
551555
self.coroutine.as_ref().and_then(|coroutine| coroutine.yield_ty)
552556
}
553557

558+
#[inline]
559+
pub fn resume_ty(&self) -> Option<Ty<'tcx>> {
560+
self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty)
561+
}
562+
554563
#[inline]
555564
pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> {
556565
self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref())

compiler/rustc_middle/src/mir/visit.rs

+8
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,12 @@ macro_rules! super_body {
996996
TyContext::YieldTy(SourceInfo::outermost(span))
997997
);
998998
}
999+
if let Some(resume_ty) = $(& $mutability)? gen.resume_ty {
1000+
$self.visit_ty(
1001+
resume_ty,
1002+
TyContext::ResumeTy(SourceInfo::outermost(span))
1003+
);
1004+
}
9991005
}
10001006

10011007
for (bb, data) in basic_blocks_iter!($body, $($mutability, $invalidate)?) {
@@ -1244,6 +1250,8 @@ pub enum TyContext {
12441250

12451251
YieldTy(SourceInfo),
12461252

1253+
ResumeTy(SourceInfo),
1254+
12471255
/// A type found at some location.
12481256
Location(Location),
12491257
}

compiler/rustc_mir_build/src/build/mod.rs

+17-10
Original file line numberDiff line numberDiff line change
@@ -488,17 +488,17 @@ fn construct_fn<'tcx>(
488488

489489
let arguments = &thir.params;
490490

491-
let (yield_ty, return_ty) = if coroutine_kind.is_some() {
491+
let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() {
492492
let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
493493
let coroutine_sig = match coroutine_ty.kind() {
494494
ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(),
495495
_ => {
496496
span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty)
497497
}
498498
};
499-
(Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
499+
(Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
500500
} else {
501-
(None, fn_sig.output())
501+
(None, None, fn_sig.output())
502502
};
503503

504504
if let Some(custom_mir_attr) =
@@ -562,9 +562,12 @@ fn construct_fn<'tcx>(
562562
} else {
563563
None
564564
};
565-
if yield_ty.is_some() {
565+
566+
if coroutine_kind.is_some() {
566567
body.coroutine.as_mut().unwrap().yield_ty = yield_ty;
568+
body.coroutine.as_mut().unwrap().resume_ty = resume_ty;
567569
}
570+
568571
body
569572
}
570573

@@ -631,28 +634,29 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
631634
let hir_id = tcx.local_def_id_to_hir_id(def_id);
632635
let coroutine_kind = tcx.coroutine_kind(def_id);
633636

634-
let (inputs, output, yield_ty) = match tcx.def_kind(def_id) {
637+
let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) {
635638
DefKind::Const
636639
| DefKind::AssocConst
637640
| DefKind::AnonConst
638641
| DefKind::InlineConst
639-
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None),
642+
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None),
640643
DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => {
641644
let sig = tcx.liberate_late_bound_regions(
642645
def_id.to_def_id(),
643646
tcx.fn_sig(def_id).instantiate_identity(),
644647
);
645-
(sig.inputs().to_vec(), sig.output(), None)
648+
(sig.inputs().to_vec(), sig.output(), None, None)
646649
}
647650
DefKind::Closure if coroutine_kind.is_some() => {
648651
let coroutine_ty = tcx.type_of(def_id).instantiate_identity();
649652
let ty::Coroutine(_, args) = coroutine_ty.kind() else {
650653
bug!("expected type of coroutine-like closure to be a coroutine")
651654
};
652655
let args = args.as_coroutine();
656+
let resume_ty = args.resume_ty();
653657
let yield_ty = args.yield_ty();
654658
let return_ty = args.return_ty();
655-
(vec![coroutine_ty, args.resume_ty()], return_ty, Some(yield_ty))
659+
(vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty))
656660
}
657661
DefKind::Closure => {
658662
let closure_ty = tcx.type_of(def_id).instantiate_identity();
@@ -666,7 +670,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
666670
ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
667671
ty::ClosureKind::FnOnce => closure_ty,
668672
};
669-
([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None)
673+
([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None)
670674
}
671675
dk => bug!("{:?} is not a body: {:?}", def_id, dk),
672676
};
@@ -705,7 +709,10 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
705709
Some(guar),
706710
);
707711

708-
body.coroutine.as_mut().map(|gen| gen.yield_ty = yield_ty);
712+
body.coroutine.as_mut().map(|gen| {
713+
gen.yield_ty = yield_ty;
714+
gen.resume_ty = resume_ty;
715+
});
709716

710717
body
711718
}

compiler/rustc_mir_transform/src/coroutine.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1733,6 +1733,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
17331733
}
17341734

17351735
body.coroutine.as_mut().unwrap().yield_ty = None;
1736+
body.coroutine.as_mut().unwrap().resume_ty = None;
17361737
body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout);
17371738

17381739
// Insert `drop(coroutine_struct)` which is used to drop upvars for coroutines in
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#![feature(coroutine_trait)]
2+
#![feature(coroutines)]
3+
4+
use std::ops::Coroutine;
5+
6+
struct Contravariant<'a>(fn(&'a ()));
7+
struct Covariant<'a>(fn() -> &'a ());
8+
9+
fn bad1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'short>> {
10+
|_: Covariant<'short>| {
11+
let a: Covariant<'long> = yield ();
12+
//~^ ERROR lifetime may not live long enough
13+
}
14+
}
15+
16+
fn bad2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'long>> {
17+
|_: Contravariant<'long>| {
18+
let a: Contravariant<'short> = yield ();
19+
//~^ ERROR lifetime may not live long enough
20+
}
21+
}
22+
23+
fn good1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'long>> {
24+
|_: Covariant<'long>| {
25+
let a: Covariant<'short> = yield ();
26+
}
27+
}
28+
29+
fn good2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'short>> {
30+
|_: Contravariant<'short>| {
31+
let a: Contravariant<'long> = yield ();
32+
}
33+
}
34+
35+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
error: lifetime may not live long enough
2+
--> $DIR/check-resume-ty-lifetimes-2.rs:11:16
3+
|
4+
LL | fn bad1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'short>> {
5+
| ------ ----- lifetime `'long` defined here
6+
| |
7+
| lifetime `'short` defined here
8+
LL | |_: Covariant<'short>| {
9+
LL | let a: Covariant<'long> = yield ();
10+
| ^^^^^^^^^^^^^^^^ type annotation requires that `'short` must outlive `'long`
11+
|
12+
= help: consider adding the following bound: `'short: 'long`
13+
help: consider adding 'move' keyword before the nested closure
14+
|
15+
LL | move |_: Covariant<'short>| {
16+
| ++++
17+
18+
error: lifetime may not live long enough
19+
--> $DIR/check-resume-ty-lifetimes-2.rs:18:40
20+
|
21+
LL | fn bad2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'long>> {
22+
| ------ ----- lifetime `'long` defined here
23+
| |
24+
| lifetime `'short` defined here
25+
LL | |_: Contravariant<'long>| {
26+
LL | let a: Contravariant<'short> = yield ();
27+
| ^^^^^^^^ yielding this value requires that `'short` must outlive `'long`
28+
|
29+
= help: consider adding the following bound: `'short: 'long`
30+
help: consider adding 'move' keyword before the nested closure
31+
|
32+
LL | move |_: Contravariant<'long>| {
33+
| ++++
34+
35+
error: aborting due to 2 previous errors
36+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#![feature(coroutine_trait)]
2+
#![feature(coroutines)]
3+
#![allow(unused)]
4+
5+
use std::ops::Coroutine;
6+
use std::ops::CoroutineState;
7+
use std::pin::pin;
8+
9+
fn mk_static(s: &str) -> &'static str {
10+
let mut storage: Option<&'static str> = None;
11+
12+
let mut coroutine = pin!(|_: &str| {
13+
let x: &'static str = yield ();
14+
//~^ ERROR lifetime may not live long enough
15+
storage = Some(x);
16+
});
17+
18+
coroutine.as_mut().resume(s);
19+
coroutine.as_mut().resume(s);
20+
21+
storage.unwrap()
22+
}
23+
24+
fn main() {
25+
let s = mk_static(&String::from("hello, world"));
26+
println!("{s}");
27+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
error: lifetime may not live long enough
2+
--> $DIR/check-resume-ty-lifetimes.rs:13:16
3+
|
4+
LL | fn mk_static(s: &str) -> &'static str {
5+
| - let's call the lifetime of this reference `'1`
6+
...
7+
LL | let x: &'static str = yield ();
8+
| ^^^^^^^^^^^^ type annotation requires that `'1` must outlive `'static`
9+
10+
error: aborting due to 1 previous error
11+

0 commit comments

Comments
 (0)