Skip to content

Commit 1db48e2

Browse files
committed
Transform async ResumeTy in generator transform
- Eliminates all the `get_context` calls that async lowering created. - Replace all `Local` `ResumeTy` types with `&mut Context<'_>`. The `Local`s that have their types replaced are: - The `resume` argument itself. - The argument to `get_context`. - The yielded value of a `yield`. The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the `get_context` function is being used to convert that back to a `&mut Context<'_>`. Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection, but rather directly use `&mut Context<'_>`, however that would currently lead to higher-kinded lifetime errors. See <#105501>. The async lowering step and the type / lifetime inference / checking are still using the `ResumeTy` indirection for the time being, and that indirection is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
1 parent 85357e3 commit 1db48e2

File tree

10 files changed

+644
-14
lines changed

10 files changed

+644
-14
lines changed

compiler/rustc_hir/src/lang_items.rs

+1
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ language_item_table! {
291291
IdentityFuture, sym::identity_future, identity_future_fn, Target::Fn, GenericRequirement::None;
292292
GetContext, sym::get_context, get_context_fn, Target::Fn, GenericRequirement::None;
293293

294+
Context, sym::Context, context, Target::Struct, GenericRequirement::None;
294295
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
295296

296297
FromFrom, sym::from, from_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;

compiler/rustc_middle/src/ty/context.rs

+9
Original file line numberDiff line numberDiff line change
@@ -1952,6 +1952,15 @@ impl<'tcx> TyCtxt<'tcx> {
19521952
self.mk_ty(GeneratorWitness(types))
19531953
}
19541954

1955+
/// Creates a `&mut Context<'_>` [`Ty`] with erased lifetimes.
1956+
pub fn mk_task_context(self) -> Ty<'tcx> {
1957+
let context_did = self.require_lang_item(LangItem::Context, None);
1958+
let context_adt_ref = self.adt_def(context_did);
1959+
let context_substs = self.intern_substs(&[self.lifetimes.re_erased.into()]);
1960+
let context_ty = self.mk_adt(context_adt_ref, context_substs);
1961+
self.mk_mut_ref(self.lifetimes.re_erased, context_ty)
1962+
}
1963+
19551964
#[inline]
19561965
pub fn mk_ty_var(self, v: TyVid) -> Ty<'tcx> {
19571966
self.mk_ty_infer(TyVar(v))

compiler/rustc_mir_transform/src/generator.rs

+111-7
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,104 @@ fn replace_local<'tcx>(
460460
new_local
461461
}
462462

463+
/// Transforms the `body` of the generator applying the following transforms:
464+
///
465+
/// - Eliminates all the `get_context` calls that async lowering created.
466+
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
467+
///
468+
/// The `Local`s that have their types replaced are:
469+
/// - The `resume` argument itself.
470+
/// - The argument to `get_context`.
471+
/// - The yielded value of a `yield`.
472+
///
473+
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
474+
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
475+
///
476+
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
477+
/// but rather directly use `&mut Context<'_>`, however that would currently
478+
/// lead to higher-kinded lifetime errors.
479+
/// See <https://github.com/rust-lang/rust/issues/105501>.
480+
///
481+
/// The async lowering step and the type / lifetime inference / checking are
482+
/// still using the `ResumeTy` indirection for the time being, and that indirection
483+
/// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
484+
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
485+
let context_mut_ref = tcx.mk_task_context();
486+
487+
// replace the type of the `resume` argument
488+
replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref);
489+
490+
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);
491+
492+
for bb in BasicBlock::new(0)..body.basic_blocks.next_index() {
493+
let bb_data = &body[bb];
494+
if bb_data.is_cleanup {
495+
continue;
496+
}
497+
498+
match &bb_data.terminator().kind {
499+
TerminatorKind::Call { func, .. } => {
500+
let func_ty = func.ty(body, tcx);
501+
if let ty::FnDef(def_id, _) = *func_ty.kind() {
502+
if def_id == get_context_def_id {
503+
let local = eliminate_get_context_call(&mut body[bb]);
504+
replace_resume_ty_local(tcx, body, local, context_mut_ref);
505+
}
506+
} else {
507+
continue;
508+
}
509+
}
510+
TerminatorKind::Yield { resume_arg, .. } => {
511+
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
512+
}
513+
_ => {}
514+
}
515+
}
516+
}
517+
518+
fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
519+
let terminator = bb_data.terminator.take().unwrap();
520+
if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind {
521+
let arg = args.pop().unwrap();
522+
let local = arg.place().unwrap().local;
523+
524+
let arg = Rvalue::Use(arg);
525+
let assign = Statement {
526+
source_info: terminator.source_info,
527+
kind: StatementKind::Assign(Box::new((destination, arg))),
528+
};
529+
bb_data.statements.push(assign);
530+
bb_data.terminator = Some(Terminator {
531+
source_info: terminator.source_info,
532+
kind: TerminatorKind::Goto { target: target.unwrap() },
533+
});
534+
local
535+
} else {
536+
bug!();
537+
}
538+
}
539+
540+
#[cfg_attr(not(debug_assertions), allow(unused))]
541+
fn replace_resume_ty_local<'tcx>(
542+
tcx: TyCtxt<'tcx>,
543+
body: &mut Body<'tcx>,
544+
local: Local,
545+
context_mut_ref: Ty<'tcx>,
546+
) {
547+
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
548+
// We have to replace the `ResumeTy` that is used for type and borrow checking
549+
// with `&mut Context<'_>` in MIR.
550+
#[cfg(debug_assertions)]
551+
{
552+
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
553+
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
554+
assert_eq!(*resume_ty_adt, expected_adt);
555+
} else {
556+
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
557+
};
558+
}
559+
}
560+
463561
struct LivenessInfo {
464562
/// Which locals are live across any suspension point.
465563
saved_locals: GeneratorSavedLocals,
@@ -1283,13 +1381,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
12831381
}
12841382
};
12851383

1286-
let is_async_kind = body.generator_kind().unwrap() != GeneratorKind::Gen;
1384+
let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_)));
12871385
let (state_adt_ref, state_substs) = if is_async_kind {
12881386
// Compute Poll<return_ty>
1289-
let state_did = tcx.require_lang_item(LangItem::Poll, None);
1290-
let state_adt_ref = tcx.adt_def(state_did);
1291-
let state_substs = tcx.intern_substs(&[body.return_ty().into()]);
1292-
(state_adt_ref, state_substs)
1387+
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
1388+
let poll_adt_ref = tcx.adt_def(poll_did);
1389+
let poll_substs = tcx.intern_substs(&[body.return_ty().into()]);
1390+
(poll_adt_ref, poll_substs)
12931391
} else {
12941392
// Compute GeneratorState<yield_ty, return_ty>
12951393
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
@@ -1303,13 +1401,19 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
13031401
// RETURN_PLACE then is a fresh unused local with type ret_ty.
13041402
let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);
13051403

1404+
// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1405+
if is_async_kind {
1406+
transform_async_context(tcx, body);
1407+
}
1408+
13061409
// We also replace the resume argument and insert an `Assign`.
13071410
// This is needed because the resume argument `_2` might be live across a `yield`, in which
13081411
// case there is no `Assign` to it that the transform can turn into a store to the generator
13091412
// state. After the yield the slot in the generator state would then be uninitialized.
13101413
let resume_local = Local::new(2);
1311-
let new_resume_local =
1312-
replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx);
1414+
let resume_ty =
1415+
if is_async_kind { tcx.mk_task_context() } else { body.local_decls[resume_local].ty };
1416+
let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
13131417

13141418
// When first entering the generator, move the resume argument into its new local.
13151419
let source_info = SourceInfo::outermost(body.span);

compiler/rustc_span/src/symbol.rs

+1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ symbols! {
164164
Capture,
165165
Center,
166166
Clone,
167+
Context,
167168
Continue,
168169
Copy,
169170
Count,

compiler/rustc_ty_utils/src/abi.rs

+27-7
Original file line numberDiff line numberDiff line change
@@ -108,21 +108,41 @@ fn fn_sig_for_fn_abi<'tcx>(
108108
// `Generator::resume(...) -> GeneratorState` function in case we
109109
// have an ordinary generator, or the `Future::poll(...) -> Poll`
110110
// function in case this is a special generator backing an async construct.
111-
let ret_ty = if tcx.generator_is_async(did) {
112-
let state_did = tcx.require_lang_item(LangItem::Poll, None);
113-
let state_adt_ref = tcx.adt_def(state_did);
114-
let state_substs = tcx.intern_substs(&[sig.return_ty.into()]);
115-
tcx.mk_adt(state_adt_ref, state_substs)
111+
let (resume_ty, ret_ty) = if tcx.generator_is_async(did) {
112+
// The signature should be `Future::poll(_, &mut Context<'_>) -> Poll<Output>`
113+
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
114+
let poll_adt_ref = tcx.adt_def(poll_did);
115+
let poll_substs = tcx.intern_substs(&[sig.return_ty.into()]);
116+
let ret_ty = tcx.mk_adt(poll_adt_ref, poll_substs);
117+
118+
// We have to replace the `ResumeTy` that is used for type and borrow checking
119+
// with `&mut Context<'_>` which is used in codegen.
120+
#[cfg(debug_assertions)]
121+
{
122+
if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() {
123+
let expected_adt =
124+
tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
125+
assert_eq!(*resume_ty_adt, expected_adt);
126+
} else {
127+
panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty);
128+
};
129+
}
130+
let context_mut_ref = tcx.mk_task_context();
131+
132+
(context_mut_ref, ret_ty)
116133
} else {
134+
// The signature should be `Generator::resume(_, Resume) -> GeneratorState<Yield, Return>`
117135
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
118136
let state_adt_ref = tcx.adt_def(state_did);
119137
let state_substs = tcx.intern_substs(&[sig.yield_ty.into(), sig.return_ty.into()]);
120-
tcx.mk_adt(state_adt_ref, state_substs)
138+
let ret_ty = tcx.mk_adt(state_adt_ref, state_substs);
139+
140+
(sig.resume_ty, ret_ty)
121141
};
122142

123143
ty::Binder::bind_with_vars(
124144
tcx.mk_fn_sig(
125-
[env_ty, sig.resume_ty].iter(),
145+
[env_ty, resume_ty].iter(),
126146
&ret_ty,
127147
false,
128148
hir::Unsafety::Normal,

library/core/src/future/mod.rs

+4
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ pub unsafe fn get_context<'a, 'b>(cx: ResumeTy) -> &'a mut Context<'b> {
112112
unsafe { &mut *cx.0.as_ptr().cast() }
113113
}
114114

115+
// FIXME(swatinem): This fn is currently needed to work around shortcomings
116+
// in type and lifetime inference.
117+
// See the comment at the bottom of `LoweringContext::make_async_expr` and
118+
// <https://github.com/rust-lang/rust/issues/104826>.
115119
#[doc(hidden)]
116120
#[unstable(feature = "gen_future", issue = "50547")]
117121
#[inline]

library/core/src/task/wake.rs

+1
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ impl RawWakerVTable {
174174
/// Currently, `Context` only serves to provide access to a [`&Waker`](Waker)
175175
/// which can be used to wake the current task.
176176
#[stable(feature = "futures_api", since = "1.36.0")]
177+
#[cfg_attr(not(bootstrap), lang = "Context")]
177178
pub struct Context<'a> {
178179
waker: &'a Waker,
179180
// Ensure we future-proof against variance changes by forcing
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// MIR for `a::{closure#0}` 0 generator_resume
2+
/* generator_layout = GeneratorLayout {
3+
field_tys: {},
4+
variant_fields: {
5+
Unresumed(0): [],
6+
Returned (1): [],
7+
Panicked (2): [],
8+
},
9+
storage_conflicts: BitMatrix(0x0) {},
10+
} */
11+
12+
fn a::{closure#0}(_1: Pin<&mut [async fn body@$DIR/async_await.rs:10:14: 10:16]>, _2: &mut Context<'_>) -> Poll<()> {
13+
debug _task_context => _4; // in scope 0 at $DIR/async_await.rs:+0:14: +0:16
14+
let mut _0: std::task::Poll<()>; // return place in scope 0 at $DIR/async_await.rs:+0:14: +0:16
15+
let mut _3: (); // in scope 0 at $DIR/async_await.rs:+0:14: +0:16
16+
let mut _4: &mut std::task::Context<'_>; // in scope 0 at $DIR/async_await.rs:+0:14: +0:16
17+
let mut _5: u32; // in scope 0 at $DIR/async_await.rs:+0:14: +0:16
18+
19+
bb0: {
20+
_5 = discriminant((*(_1.0: &mut [async fn body@$DIR/async_await.rs:10:14: 10:16]))); // scope 0 at $DIR/async_await.rs:+0:14: +0:16
21+
switchInt(move _5) -> [0: bb1, 1: bb2, otherwise: bb3]; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
22+
}
23+
24+
bb1: {
25+
_4 = move _2; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
26+
_3 = const (); // scope 0 at $DIR/async_await.rs:+0:14: +0:16
27+
Deinit(_0); // scope 0 at $DIR/async_await.rs:+0:16: +0:16
28+
((_0 as Ready).0: ()) = move _3; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
29+
discriminant(_0) = 0; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
30+
discriminant((*(_1.0: &mut [async fn body@$DIR/async_await.rs:10:14: 10:16]))) = 1; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
31+
return; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
32+
}
33+
34+
bb2: {
35+
assert(const false, "`async fn` resumed after completion") -> bb2; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
36+
}
37+
38+
bb3: {
39+
unreachable; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
40+
}
41+
}

0 commit comments

Comments
 (0)