@@ -460,6 +460,104 @@ fn replace_local<'tcx>(
460
460
new_local
461
461
}
462
462
463
+ /// Transforms the `body` of the generator applying the following transforms:
464
+ ///
465
+ /// - Eliminates all the `get_context` calls that async lowering created.
466
+ /// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
467
+ ///
468
+ /// The `Local`s that have their types replaced are:
469
+ /// - The `resume` argument itself.
470
+ /// - The argument to `get_context`.
471
+ /// - The yielded value of a `yield`.
472
+ ///
473
+ /// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
474
+ /// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
475
+ ///
476
+ /// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
477
+ /// but rather directly use `&mut Context<'_>`, however that would currently
478
+ /// lead to higher-kinded lifetime errors.
479
+ /// See <https://github.com/rust-lang/rust/issues/105501>.
480
+ ///
481
+ /// The async lowering step and the type / lifetime inference / checking are
482
+ /// still using the `ResumeTy` indirection for the time being, and that indirection
483
+ /// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
484
+ fn transform_async_context < ' tcx > ( tcx : TyCtxt < ' tcx > , body : & mut Body < ' tcx > ) {
485
+ let context_mut_ref = tcx. mk_task_context ( ) ;
486
+
487
+ // replace the type of the `resume` argument
488
+ replace_resume_ty_local ( tcx, body, Local :: new ( 2 ) , context_mut_ref) ;
489
+
490
+ let get_context_def_id = tcx. require_lang_item ( LangItem :: GetContext , None ) ;
491
+
492
+ for bb in BasicBlock :: new ( 0 ) ..body. basic_blocks . next_index ( ) {
493
+ let bb_data = & body[ bb] ;
494
+ if bb_data. is_cleanup {
495
+ continue ;
496
+ }
497
+
498
+ match & bb_data. terminator ( ) . kind {
499
+ TerminatorKind :: Call { func, .. } => {
500
+ let func_ty = func. ty ( body, tcx) ;
501
+ if let ty:: FnDef ( def_id, _) = * func_ty. kind ( ) {
502
+ if def_id == get_context_def_id {
503
+ let local = eliminate_get_context_call ( & mut body[ bb] ) ;
504
+ replace_resume_ty_local ( tcx, body, local, context_mut_ref) ;
505
+ }
506
+ } else {
507
+ continue ;
508
+ }
509
+ }
510
+ TerminatorKind :: Yield { resume_arg, .. } => {
511
+ replace_resume_ty_local ( tcx, body, resume_arg. local , context_mut_ref) ;
512
+ }
513
+ _ => { }
514
+ }
515
+ }
516
+ }
517
+
518
+ fn eliminate_get_context_call < ' tcx > ( bb_data : & mut BasicBlockData < ' tcx > ) -> Local {
519
+ let terminator = bb_data. terminator . take ( ) . unwrap ( ) ;
520
+ if let TerminatorKind :: Call { mut args, destination, target, .. } = terminator. kind {
521
+ let arg = args. pop ( ) . unwrap ( ) ;
522
+ let local = arg. place ( ) . unwrap ( ) . local ;
523
+
524
+ let arg = Rvalue :: Use ( arg) ;
525
+ let assign = Statement {
526
+ source_info : terminator. source_info ,
527
+ kind : StatementKind :: Assign ( Box :: new ( ( destination, arg) ) ) ,
528
+ } ;
529
+ bb_data. statements . push ( assign) ;
530
+ bb_data. terminator = Some ( Terminator {
531
+ source_info : terminator. source_info ,
532
+ kind : TerminatorKind :: Goto { target : target. unwrap ( ) } ,
533
+ } ) ;
534
+ local
535
+ } else {
536
+ bug ! ( ) ;
537
+ }
538
+ }
539
+
540
+ #[ cfg_attr( not( debug_assertions) , allow( unused) ) ]
541
+ fn replace_resume_ty_local < ' tcx > (
542
+ tcx : TyCtxt < ' tcx > ,
543
+ body : & mut Body < ' tcx > ,
544
+ local : Local ,
545
+ context_mut_ref : Ty < ' tcx > ,
546
+ ) {
547
+ let local_ty = std:: mem:: replace ( & mut body. local_decls [ local] . ty , context_mut_ref) ;
548
+ // We have to replace the `ResumeTy` that is used for type and borrow checking
549
+ // with `&mut Context<'_>` in MIR.
550
+ #[ cfg( debug_assertions) ]
551
+ {
552
+ if let ty:: Adt ( resume_ty_adt, _) = local_ty. kind ( ) {
553
+ let expected_adt = tcx. adt_def ( tcx. require_lang_item ( LangItem :: ResumeTy , None ) ) ;
554
+ assert_eq ! ( * resume_ty_adt, expected_adt) ;
555
+ } else {
556
+ panic ! ( "expected `ResumeTy`, found `{:?}`" , local_ty) ;
557
+ } ;
558
+ }
559
+ }
560
+
463
561
struct LivenessInfo {
464
562
/// Which locals are live across any suspension point.
465
563
saved_locals : GeneratorSavedLocals ,
@@ -1283,13 +1381,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1283
1381
}
1284
1382
} ;
1285
1383
1286
- let is_async_kind = body. generator_kind ( ) . unwrap ( ) != GeneratorKind :: Gen ;
1384
+ let is_async_kind = matches ! ( body. generator_kind( ) , Some ( GeneratorKind :: Async ( _ ) ) ) ;
1287
1385
let ( state_adt_ref, state_substs) = if is_async_kind {
1288
1386
// Compute Poll<return_ty>
1289
- let state_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
1290
- let state_adt_ref = tcx. adt_def ( state_did ) ;
1291
- let state_substs = tcx. intern_substs ( & [ body. return_ty ( ) . into ( ) ] ) ;
1292
- ( state_adt_ref , state_substs )
1387
+ let poll_did = tcx. require_lang_item ( LangItem :: Poll , None ) ;
1388
+ let poll_adt_ref = tcx. adt_def ( poll_did ) ;
1389
+ let poll_substs = tcx. intern_substs ( & [ body. return_ty ( ) . into ( ) ] ) ;
1390
+ ( poll_adt_ref , poll_substs )
1293
1391
} else {
1294
1392
// Compute GeneratorState<yield_ty, return_ty>
1295
1393
let state_did = tcx. require_lang_item ( LangItem :: GeneratorState , None ) ;
@@ -1303,13 +1401,19 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1303
1401
// RETURN_PLACE then is a fresh unused local with type ret_ty.
1304
1402
let new_ret_local = replace_local ( RETURN_PLACE , ret_ty, body, tcx) ;
1305
1403
1404
+ // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1405
+ if is_async_kind {
1406
+ transform_async_context ( tcx, body) ;
1407
+ }
1408
+
1306
1409
// We also replace the resume argument and insert an `Assign`.
1307
1410
// This is needed because the resume argument `_2` might be live across a `yield`, in which
1308
1411
// case there is no `Assign` to it that the transform can turn into a store to the generator
1309
1412
// state. After the yield the slot in the generator state would then be uninitialized.
1310
1413
let resume_local = Local :: new ( 2 ) ;
1311
- let new_resume_local =
1312
- replace_local ( resume_local, body. local_decls [ resume_local] . ty , body, tcx) ;
1414
+ let resume_ty =
1415
+ if is_async_kind { tcx. mk_task_context ( ) } else { body. local_decls [ resume_local] . ty } ;
1416
+ let new_resume_local = replace_local ( resume_local, resume_ty, body, tcx) ;
1313
1417
1314
1418
// When first entering the generator, move the resume argument into its new local.
1315
1419
let source_info = SourceInfo :: outermost ( body. span ) ;
0 commit comments