@@ -470,6 +470,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
470
470
}
471
471
}
472
472
473
+ /// Lower an `async` construct to a generator that is then wrapped so it implements `Future`.
474
+ ///
475
+ /// This results in:
476
+ ///
477
+ /// ```text
478
+ /// std::future::from_generator(static move? |_task_context| -> <ret_ty> {
479
+ /// <body>
480
+ /// })
481
+ /// ```
473
482
pub ( super ) fn make_async_expr (
474
483
& mut self ,
475
484
capture_clause : CaptureBy ,
@@ -480,17 +489,42 @@ impl<'hir> LoweringContext<'_, 'hir> {
480
489
body : impl FnOnce ( & mut Self ) -> hir:: Expr < ' hir > ,
481
490
) -> hir:: ExprKind < ' hir > {
482
491
let output = match ret_ty {
483
- Some ( ty) => FnRetTy :: Ty ( ty ) ,
484
- None => FnRetTy :: Default ( span) ,
492
+ Some ( ty) => hir :: FnRetTy :: Return ( self . lower_ty ( & ty , ImplTraitContext :: disallowed ( ) ) ) ,
493
+ None => hir :: FnRetTy :: DefaultReturn ( span) ,
485
494
} ;
486
- let ast_decl = FnDecl { inputs : vec ! [ ] , output } ;
487
- let decl = self . lower_fn_decl ( & ast_decl, None , /* impl trait allowed */ false , None ) ;
488
- let body_id = self . lower_fn_body ( & ast_decl, |this| {
495
+
496
+ // Resume argument type. We let the compiler infer this to simplify the lowering. It is
497
+ // fully constrained by `future::from_generator`.
498
+ let input_ty = hir:: Ty { hir_id : self . next_id ( ) , kind : hir:: TyKind :: Infer , span } ;
499
+
500
+ // The closure/generator `FnDecl` takes a single (resume) argument of type `input_ty`.
501
+ let decl = self . arena . alloc ( hir:: FnDecl {
502
+ inputs : arena_vec ! [ self ; input_ty] ,
503
+ output,
504
+ c_variadic : false ,
505
+ implicit_self : hir:: ImplicitSelfKind :: None ,
506
+ } ) ;
507
+
508
+ // Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
509
+ let ( pat, task_context_hid) = self . pat_ident_binding_mode (
510
+ span,
511
+ Ident :: with_dummy_span ( sym:: _task_context) ,
512
+ hir:: BindingAnnotation :: Mutable ,
513
+ ) ;
514
+ let param = hir:: Param { attrs : & [ ] , hir_id : self . next_id ( ) , pat, span } ;
515
+ let params = arena_vec ! [ self ; param] ;
516
+
517
+ let body_id = self . lower_body ( move |this| {
489
518
this. generator_kind = Some ( hir:: GeneratorKind :: Async ( async_gen_kind) ) ;
490
- body ( this)
519
+
520
+ let old_ctx = this. task_context ;
521
+ this. task_context = Some ( task_context_hid) ;
522
+ let res = body ( this) ;
523
+ this. task_context = old_ctx;
524
+ ( params, res)
491
525
} ) ;
492
526
493
- // `static || -> <ret_ty> { body }`:
527
+ // `static |_task_context | -> <ret_ty> { body }`:
494
528
let generator_kind = hir:: ExprKind :: Closure (
495
529
capture_clause,
496
530
decl,
@@ -523,13 +557,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
523
557
/// ```rust
524
558
/// match <expr> {
525
559
/// mut pinned => loop {
526
- /// match ::std::future::poll_with_tls_context(unsafe {
527
- /// <::std::pin::Pin>::new_unchecked(&mut pinned)
528
- /// }) {
560
+ /// match unsafe { ::std::future::poll_with_context(
561
+ /// <::std::pin::Pin>::new_unchecked(&mut pinned),
562
+ /// task_context,
563
+ /// ) } {
529
564
/// ::std::task::Poll::Ready(result) => break result,
530
565
/// ::std::task::Poll::Pending => {}
531
566
/// }
532
- /// yield ();
567
+ /// task_context = yield ();
533
568
/// }
534
569
/// }
535
570
/// ```
@@ -561,12 +596,23 @@ impl<'hir> LoweringContext<'_, 'hir> {
561
596
let ( pinned_pat, pinned_pat_hid) =
562
597
self . pat_ident_binding_mode ( span, pinned_ident, hir:: BindingAnnotation :: Mutable ) ;
563
598
564
- // ::std::future::poll_with_tls_context(unsafe {
565
- // ::std::pin::Pin::new_unchecked(&mut pinned)
566
- // })`
599
+ let task_context_ident = Ident :: with_dummy_span ( sym:: _task_context) ;
600
+
601
+ // unsafe {
602
+ // ::std::future::poll_with_context(
603
+ // ::std::pin::Pin::new_unchecked(&mut pinned),
604
+ // task_context,
605
+ // )
606
+ // }
567
607
let poll_expr = {
568
608
let pinned = self . expr_ident ( span, pinned_ident, pinned_pat_hid) ;
569
609
let ref_mut_pinned = self . expr_mut_addr_of ( span, pinned) ;
610
+ let task_context = if let Some ( task_context_hid) = self . task_context {
611
+ self . expr_ident_mut ( span, task_context_ident, task_context_hid)
612
+ } else {
613
+ // Use of `await` outside of an async context, we cannot use `task_context` here.
614
+ self . expr_err ( span)
615
+ } ;
570
616
let pin_ty_id = self . next_id ( ) ;
571
617
let new_unchecked_expr_kind = self . expr_call_std_assoc_fn (
572
618
pin_ty_id,
@@ -575,14 +621,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
575
621
"new_unchecked" ,
576
622
arena_vec ! [ self ; ref_mut_pinned] ,
577
623
) ;
578
- let new_unchecked =
579
- self . arena . alloc ( self . expr ( span, new_unchecked_expr_kind, ThinVec :: new ( ) ) ) ;
580
- let unsafe_expr = self . expr_unsafe ( new_unchecked) ;
581
- self . expr_call_std_path (
624
+ let new_unchecked = self . expr ( span, new_unchecked_expr_kind, ThinVec :: new ( ) ) ;
625
+ let call = self . expr_call_std_path (
582
626
gen_future_span,
583
- & [ sym:: future, sym:: poll_with_tls_context] ,
584
- arena_vec ! [ self ; unsafe_expr] ,
585
- )
627
+ & [ sym:: future, sym:: poll_with_context] ,
628
+ arena_vec ! [ self ; new_unchecked, task_context] ,
629
+ ) ;
630
+ self . arena . alloc ( self . expr_unsafe ( call) )
586
631
} ;
587
632
588
633
// `::std::task::Poll::Ready(result) => break result`
@@ -622,14 +667,26 @@ impl<'hir> LoweringContext<'_, 'hir> {
622
667
self . stmt_expr ( span, match_expr)
623
668
} ;
624
669
670
+ // task_context = yield ();
625
671
let yield_stmt = {
626
672
let unit = self . expr_unit ( span) ;
627
673
let yield_expr = self . expr (
628
674
span,
629
675
hir:: ExprKind :: Yield ( unit, hir:: YieldSource :: Await ) ,
630
676
ThinVec :: new ( ) ,
631
677
) ;
632
- self . stmt_expr ( span, yield_expr)
678
+ let yield_expr = self . arena . alloc ( yield_expr) ;
679
+
680
+ if let Some ( task_context_hid) = self . task_context {
681
+ let lhs = self . expr_ident ( span, task_context_ident, task_context_hid) ;
682
+ let assign =
683
+ self . expr ( span, hir:: ExprKind :: Assign ( lhs, yield_expr, span) , AttrVec :: new ( ) ) ;
684
+ self . stmt_expr ( span, assign)
685
+ } else {
686
+ // Use of `await` outside of an async context. Return `yield_expr` so that we can
687
+ // proceed with type checking.
688
+ self . stmt ( span, hir:: StmtKind :: Semi ( yield_expr) )
689
+ }
633
690
} ;
634
691
635
692
let loop_block = self . block_all ( span, arena_vec ! [ self ; inner_match_stmt, yield_stmt] , None ) ;
0 commit comments