Skip to content

Commit ef7c8a1

Browse files
authored
Rollup merge of #69033 - jonas-schievink:resume-with-context, r=tmandry
Use generator resume arguments in the async/await lowering This removes the TLS requirement from async/await and enables it in `#![no_std]` crates. Closes #56974 I'm not confident the HIR lowering is completely correct, there seem to be quite a few undocumented invariants in there. The `async-std` and tokio test suites are passing with these changes though.
2 parents 4b91729 + db0126a commit ef7c8a1

File tree

9 files changed

+204
-36
lines changed

9 files changed

+204
-36
lines changed

src/libcore/future/mod.rs

+78
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,84 @@
22

33
//! Asynchronous values.
44
5+
#[cfg(not(bootstrap))]
6+
use crate::{
7+
ops::{Generator, GeneratorState},
8+
pin::Pin,
9+
ptr::NonNull,
10+
task::{Context, Poll},
11+
};
12+
513
mod future;
614
#[stable(feature = "futures_api", since = "1.36.0")]
715
pub use self::future::Future;
16+
17+
/// This type is needed because:
18+
///
19+
/// a) Generators cannot implement `for<'a, 'b> Generator<&'a mut Context<'b>>`, so we need to pass
20+
/// a raw pointer (see https://github.com/rust-lang/rust/issues/68923).
21+
/// b) Raw pointers and `NonNull` aren't `Send` or `Sync`, so that would make every single future
22+
/// non-Send/Sync as well, and we don't want that.
23+
///
24+
/// It also simplifies the HIR lowering of `.await`.
25+
#[doc(hidden)]
26+
#[unstable(feature = "gen_future", issue = "50547")]
27+
#[cfg(not(bootstrap))]
28+
#[derive(Debug, Copy, Clone)]
29+
pub struct ResumeTy(NonNull<Context<'static>>);
30+
31+
#[unstable(feature = "gen_future", issue = "50547")]
32+
#[cfg(not(bootstrap))]
33+
unsafe impl Send for ResumeTy {}
34+
35+
#[unstable(feature = "gen_future", issue = "50547")]
36+
#[cfg(not(bootstrap))]
37+
unsafe impl Sync for ResumeTy {}
38+
39+
/// Wrap a generator in a future.
40+
///
41+
/// This function returns a `GenFuture` underneath, but hides it in `impl Trait` to give
42+
/// better error messages (`impl Future` rather than `GenFuture<[closure.....]>`).
43+
// This is `const` to avoid extra errors after we recover from `const async fn`
44+
#[doc(hidden)]
45+
#[unstable(feature = "gen_future", issue = "50547")]
46+
#[cfg(not(bootstrap))]
47+
#[inline]
48+
pub const fn from_generator<T>(gen: T) -> impl Future<Output = T::Return>
49+
where
50+
T: Generator<ResumeTy, Yield = ()>,
51+
{
52+
struct GenFuture<T: Generator<ResumeTy, Yield = ()>>(T);
53+
54+
// We rely on the fact that async/await futures are immovable in order to create
55+
// self-referential borrows in the underlying generator.
56+
impl<T: Generator<ResumeTy, Yield = ()>> !Unpin for GenFuture<T> {}
57+
58+
impl<T: Generator<ResumeTy, Yield = ()>> Future for GenFuture<T> {
59+
type Output = T::Return;
60+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
61+
// Safety: Safe because we're !Unpin + !Drop, and this is just a field projection.
62+
let gen = unsafe { Pin::map_unchecked_mut(self, |s| &mut s.0) };
63+
64+
// Resume the generator, turning the `&mut Context` into a `NonNull` raw pointer. The
65+
// `.await` lowering will safely cast that back to a `&mut Context`.
66+
match gen.resume(ResumeTy(NonNull::from(cx).cast::<Context<'static>>())) {
67+
GeneratorState::Yielded(()) => Poll::Pending,
68+
GeneratorState::Complete(x) => Poll::Ready(x),
69+
}
70+
}
71+
}
72+
73+
GenFuture(gen)
74+
}
75+
76+
#[doc(hidden)]
77+
#[unstable(feature = "gen_future", issue = "50547")]
78+
#[cfg(not(bootstrap))]
79+
#[inline]
80+
pub unsafe fn poll_with_context<F>(f: Pin<&mut F>, mut cx: ResumeTy) -> Poll<F::Output>
81+
where
82+
F: Future,
83+
{
84+
F::poll(f, cx.0.as_mut())
85+
}

src/librustc_ast_lowering/expr.rs

+79-22
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
470470
}
471471
}
472472

473+
/// Lower an `async` construct to a generator that is then wrapped so it implements `Future`.
474+
///
475+
/// This results in:
476+
///
477+
/// ```text
478+
/// std::future::from_generator(static move? |_task_context| -> <ret_ty> {
479+
/// <body>
480+
/// })
481+
/// ```
473482
pub(super) fn make_async_expr(
474483
&mut self,
475484
capture_clause: CaptureBy,
@@ -480,17 +489,42 @@ impl<'hir> LoweringContext<'_, 'hir> {
480489
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
481490
) -> hir::ExprKind<'hir> {
482491
let output = match ret_ty {
483-
Some(ty) => FnRetTy::Ty(ty),
484-
None => FnRetTy::Default(span),
492+
Some(ty) => hir::FnRetTy::Return(self.lower_ty(&ty, ImplTraitContext::disallowed())),
493+
None => hir::FnRetTy::DefaultReturn(span),
485494
};
486-
let ast_decl = FnDecl { inputs: vec![], output };
487-
let decl = self.lower_fn_decl(&ast_decl, None, /* impl trait allowed */ false, None);
488-
let body_id = self.lower_fn_body(&ast_decl, |this| {
495+
496+
// Resume argument type. We let the compiler infer this to simplify the lowering. It is
497+
// fully constrained by `future::from_generator`.
498+
let input_ty = hir::Ty { hir_id: self.next_id(), kind: hir::TyKind::Infer, span };
499+
500+
// The closure/generator `FnDecl` takes a single (resume) argument of type `input_ty`.
501+
let decl = self.arena.alloc(hir::FnDecl {
502+
inputs: arena_vec![self; input_ty],
503+
output,
504+
c_variadic: false,
505+
implicit_self: hir::ImplicitSelfKind::None,
506+
});
507+
508+
// Lower the argument pattern/ident. The ident is used again in the `.await` lowering.
509+
let (pat, task_context_hid) = self.pat_ident_binding_mode(
510+
span,
511+
Ident::with_dummy_span(sym::_task_context),
512+
hir::BindingAnnotation::Mutable,
513+
);
514+
let param = hir::Param { attrs: &[], hir_id: self.next_id(), pat, span };
515+
let params = arena_vec![self; param];
516+
517+
let body_id = self.lower_body(move |this| {
489518
this.generator_kind = Some(hir::GeneratorKind::Async(async_gen_kind));
490-
body(this)
519+
520+
let old_ctx = this.task_context;
521+
this.task_context = Some(task_context_hid);
522+
let res = body(this);
523+
this.task_context = old_ctx;
524+
(params, res)
491525
});
492526

493-
// `static || -> <ret_ty> { body }`:
527+
// `static |_task_context| -> <ret_ty> { body }`:
494528
let generator_kind = hir::ExprKind::Closure(
495529
capture_clause,
496530
decl,
@@ -523,13 +557,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
523557
/// ```rust
524558
/// match <expr> {
525559
/// mut pinned => loop {
526-
/// match ::std::future::poll_with_tls_context(unsafe {
527-
/// <::std::pin::Pin>::new_unchecked(&mut pinned)
528-
/// }) {
560+
/// match unsafe { ::std::future::poll_with_context(
561+
/// <::std::pin::Pin>::new_unchecked(&mut pinned),
562+
/// task_context,
563+
/// ) } {
529564
/// ::std::task::Poll::Ready(result) => break result,
530565
/// ::std::task::Poll::Pending => {}
531566
/// }
532-
/// yield ();
567+
/// task_context = yield ();
533568
/// }
534569
/// }
535570
/// ```
@@ -561,12 +596,23 @@ impl<'hir> LoweringContext<'_, 'hir> {
561596
let (pinned_pat, pinned_pat_hid) =
562597
self.pat_ident_binding_mode(span, pinned_ident, hir::BindingAnnotation::Mutable);
563598

564-
// ::std::future::poll_with_tls_context(unsafe {
565-
// ::std::pin::Pin::new_unchecked(&mut pinned)
566-
// })`
599+
let task_context_ident = Ident::with_dummy_span(sym::_task_context);
600+
601+
// unsafe {
602+
// ::std::future::poll_with_context(
603+
// ::std::pin::Pin::new_unchecked(&mut pinned),
604+
// task_context,
605+
// )
606+
// }
567607
let poll_expr = {
568608
let pinned = self.expr_ident(span, pinned_ident, pinned_pat_hid);
569609
let ref_mut_pinned = self.expr_mut_addr_of(span, pinned);
610+
let task_context = if let Some(task_context_hid) = self.task_context {
611+
self.expr_ident_mut(span, task_context_ident, task_context_hid)
612+
} else {
613+
// Use of `await` outside of an async context, we cannot use `task_context` here.
614+
self.expr_err(span)
615+
};
570616
let pin_ty_id = self.next_id();
571617
let new_unchecked_expr_kind = self.expr_call_std_assoc_fn(
572618
pin_ty_id,
@@ -575,14 +621,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
575621
"new_unchecked",
576622
arena_vec![self; ref_mut_pinned],
577623
);
578-
let new_unchecked =
579-
self.arena.alloc(self.expr(span, new_unchecked_expr_kind, ThinVec::new()));
580-
let unsafe_expr = self.expr_unsafe(new_unchecked);
581-
self.expr_call_std_path(
624+
let new_unchecked = self.expr(span, new_unchecked_expr_kind, ThinVec::new());
625+
let call = self.expr_call_std_path(
582626
gen_future_span,
583-
&[sym::future, sym::poll_with_tls_context],
584-
arena_vec![self; unsafe_expr],
585-
)
627+
&[sym::future, sym::poll_with_context],
628+
arena_vec![self; new_unchecked, task_context],
629+
);
630+
self.arena.alloc(self.expr_unsafe(call))
586631
};
587632

588633
// `::std::task::Poll::Ready(result) => break result`
@@ -622,14 +667,26 @@ impl<'hir> LoweringContext<'_, 'hir> {
622667
self.stmt_expr(span, match_expr)
623668
};
624669

670+
// task_context = yield ();
625671
let yield_stmt = {
626672
let unit = self.expr_unit(span);
627673
let yield_expr = self.expr(
628674
span,
629675
hir::ExprKind::Yield(unit, hir::YieldSource::Await),
630676
ThinVec::new(),
631677
);
632-
self.stmt_expr(span, yield_expr)
678+
let yield_expr = self.arena.alloc(yield_expr);
679+
680+
if let Some(task_context_hid) = self.task_context {
681+
let lhs = self.expr_ident(span, task_context_ident, task_context_hid);
682+
let assign =
683+
self.expr(span, hir::ExprKind::Assign(lhs, yield_expr, span), AttrVec::new());
684+
self.stmt_expr(span, assign)
685+
} else {
686+
// Use of `await` outside of an async context. Return `yield_expr` so that we can
687+
// proceed with type checking.
688+
self.stmt(span, hir::StmtKind::Semi(yield_expr))
689+
}
633690
};
634691

635692
let loop_block = self.block_all(span, arena_vec![self; inner_match_stmt, yield_stmt], None);

src/librustc_ast_lowering/item.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
814814
}
815815

816816
/// Construct `ExprKind::Err` for the given `span`.
817-
fn expr_err(&mut self, span: Span) -> hir::Expr<'hir> {
817+
crate fn expr_err(&mut self, span: Span) -> hir::Expr<'hir> {
818818
self.expr(span, hir::ExprKind::Err, AttrVec::new())
819819
}
820820

@@ -960,7 +960,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
960960
id
961961
}
962962

963-
fn lower_body(
963+
pub(super) fn lower_body(
964964
&mut self,
965965
f: impl FnOnce(&mut Self) -> (&'hir [hir::Param<'hir>], hir::Expr<'hir>),
966966
) -> hir::BodyId {

src/librustc_ast_lowering/lib.rs

+5
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ struct LoweringContext<'a, 'hir: 'a> {
116116

117117
generator_kind: Option<hir::GeneratorKind>,
118118

119+
/// When inside an `async` context, this is the `HirId` of the
120+
/// `task_context` local bound to the resume argument of the generator.
121+
task_context: Option<hir::HirId>,
122+
119123
/// Used to get the current `fn`'s def span to point to when using `await`
120124
/// outside of an `async fn`.
121125
current_item: Option<Span>,
@@ -294,6 +298,7 @@ pub fn lower_crate<'a, 'hir>(
294298
item_local_id_counters: Default::default(),
295299
node_id_to_hir_id: IndexVec::new(),
296300
generator_kind: None,
301+
task_context: None,
297302
current_item: None,
298303
lifetimes_to_define: Vec::new(),
299304
is_collecting_in_band_lifetimes: false,

src/librustc_mir/borrow_check/type_check/input_output.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,16 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
6464
}
6565
};
6666

67+
debug!(
68+
"equate_inputs_and_outputs: normalized_input_tys = {:?}, local_decls = {:?}",
69+
normalized_input_tys, body.local_decls
70+
);
71+
6772
// Equate expected input tys with those in the MIR.
6873
for (&normalized_input_ty, argument_index) in normalized_input_tys.iter().zip(0..) {
6974
// In MIR, argument N is stored in local N+1.
7075
let local = Local::new(argument_index + 1);
7176

72-
debug!("equate_inputs_and_outputs: normalized_input_ty = {:?}", normalized_input_ty);
73-
7477
let mir_input_ty = body.local_decls[local].ty;
7578
let mir_input_span = body.local_decls[local].source_info.span;
7679
self.equate_normalized_input_or_output(

src/librustc_span/symbol.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ symbols! {
544544
plugin_registrar,
545545
plugins,
546546
Poll,
547-
poll_with_tls_context,
547+
poll_with_context,
548548
powerpc_target_feature,
549549
precise_pointer_size_matching,
550550
pref_align_of,
@@ -720,6 +720,7 @@ symbols! {
720720
target_has_atomic_load_store,
721721
target_thread_local,
722722
task,
723+
_task_context,
723724
tbm_target_feature,
724725
termination_trait,
725726
termination_trait_test,

src/libstd/future.rs

+18-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
//! Asynchronous values.
22
3-
use core::cell::Cell;
4-
use core::marker::Unpin;
5-
use core::ops::{Drop, Generator, GeneratorState};
6-
use core::option::Option;
7-
use core::pin::Pin;
8-
use core::ptr::NonNull;
9-
use core::task::{Context, Poll};
3+
#[cfg(bootstrap)]
4+
use core::{
5+
cell::Cell,
6+
marker::Unpin,
7+
ops::{Drop, Generator, GeneratorState},
8+
pin::Pin,
9+
ptr::NonNull,
10+
task::{Context, Poll},
11+
};
1012

1113
#[doc(inline)]
1214
#[stable(feature = "futures_api", since = "1.36.0")]
@@ -17,22 +19,26 @@ pub use core::future::*;
1719
/// This function returns a `GenFuture` underneath, but hides it in `impl Trait` to give
1820
/// better error messages (`impl Future` rather than `GenFuture<[closure.....]>`).
1921
// This is `const` to avoid extra errors after we recover from `const async fn`
22+
#[cfg(bootstrap)]
2023
#[doc(hidden)]
2124
#[unstable(feature = "gen_future", issue = "50547")]
2225
pub const fn from_generator<T: Generator<Yield = ()>>(x: T) -> impl Future<Output = T::Return> {
2326
GenFuture(x)
2427
}
2528

2629
/// A wrapper around generators used to implement `Future` for `async`/`await` code.
30+
#[cfg(bootstrap)]
2731
#[doc(hidden)]
2832
#[unstable(feature = "gen_future", issue = "50547")]
2933
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
3034
struct GenFuture<T: Generator<Yield = ()>>(T);
3135

3236
// We rely on the fact that async/await futures are immovable in order to create
3337
// self-referential borrows in the underlying generator.
38+
#[cfg(bootstrap)]
3439
impl<T: Generator<Yield = ()>> !Unpin for GenFuture<T> {}
3540

41+
#[cfg(bootstrap)]
3642
#[doc(hidden)]
3743
#[unstable(feature = "gen_future", issue = "50547")]
3844
impl<T: Generator<Yield = ()>> Future for GenFuture<T> {
@@ -48,12 +54,15 @@ impl<T: Generator<Yield = ()>> Future for GenFuture<T> {
4854
}
4955
}
5056

57+
#[cfg(bootstrap)]
5158
thread_local! {
5259
static TLS_CX: Cell<Option<NonNull<Context<'static>>>> = Cell::new(None);
5360
}
5461

62+
#[cfg(bootstrap)]
5563
struct SetOnDrop(Option<NonNull<Context<'static>>>);
5664

65+
#[cfg(bootstrap)]
5766
impl Drop for SetOnDrop {
5867
fn drop(&mut self) {
5968
TLS_CX.with(|tls_cx| {
@@ -64,13 +73,15 @@ impl Drop for SetOnDrop {
6473

6574
// Safety: the returned guard must drop before `cx` is dropped and before
6675
// any previous guard is dropped.
76+
#[cfg(bootstrap)]
6777
unsafe fn set_task_context(cx: &mut Context<'_>) -> SetOnDrop {
6878
// transmute the context's lifetime to 'static so we can store it.
6979
let cx = core::mem::transmute::<&mut Context<'_>, &mut Context<'static>>(cx);
7080
let old_cx = TLS_CX.with(|tls_cx| tls_cx.replace(Some(NonNull::from(cx))));
7181
SetOnDrop(old_cx)
7282
}
7383

84+
#[cfg(bootstrap)]
7485
#[doc(hidden)]
7586
#[unstable(feature = "gen_future", issue = "50547")]
7687
/// Polls a future in the current thread-local task waker.

0 commit comments

Comments
 (0)