Desugar for await loops

This commit is contained in:
Eric Holk 2023-12-08 17:00:11 -08:00
parent 27d6539a46
commit 97df0d3657
No known key found for this signature in database
GPG Key ID: 8EA6B43ED4CE0911
8 changed files with 125 additions and 30 deletions

View File

@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
), ),
ExprKind::Try(sub_expr) => self.lower_expr_try(e.span, sub_expr), ExprKind::Try(sub_expr) => self.lower_expr_try(e.span, sub_expr),
ExprKind::Paren(_) | ExprKind::ForLoop{..} => { ExprKind::Paren(_) | ExprKind::ForLoop { .. } => {
unreachable!("already handled") unreachable!("already handled")
} }
@ -874,6 +874,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
/// } /// }
/// ``` /// ```
fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> { fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> {
let expr = self.arena.alloc(self.lower_expr_mut(expr));
self.make_lowered_await(await_kw_span, expr, FutureKind::Future)
}
/// Takes an expr that has already been lowered and generates a desugared await loop around it
fn make_lowered_await(
&mut self,
await_kw_span: Span,
expr: &'hir hir::Expr<'hir>,
await_kind: FutureKind,
) -> hir::ExprKind<'hir> {
let full_span = expr.span.to(await_kw_span); let full_span = expr.span.to(await_kw_span);
let is_async_gen = match self.coroutine_kind { let is_async_gen = match self.coroutine_kind {
@ -887,13 +898,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
} }
}; };
let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None); let features = match await_kind {
FutureKind::Future => None,
FutureKind::AsyncIterator => Some(self.allow_for_await.clone()),
};
let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, features);
let gen_future_span = self.mark_span_with_reason( let gen_future_span = self.mark_span_with_reason(
DesugaringKind::Await, DesugaringKind::Await,
full_span, full_span,
Some(self.allow_gen_future.clone()), Some(self.allow_gen_future.clone()),
); );
let expr = self.lower_expr_mut(expr);
let expr_hir_id = expr.hir_id; let expr_hir_id = expr.hir_id;
// Note that the name of this binding must not be changed to something else because // Note that the name of this binding must not be changed to something else because
@ -933,11 +947,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::LangItem::GetContext, hir::LangItem::GetContext,
arena_vec![self; task_context], arena_vec![self; task_context],
); );
let call = self.expr_call_lang_item_fn( let call = match await_kind {
span, FutureKind::Future => self.expr_call_lang_item_fn(
hir::LangItem::FuturePoll, span,
arena_vec![self; new_unchecked, get_context], hir::LangItem::FuturePoll,
); arena_vec![self; new_unchecked, get_context],
),
FutureKind::AsyncIterator => self.expr_call_lang_item_fn(
span,
hir::LangItem::AsyncIteratorPollNext,
arena_vec![self; new_unchecked, get_context],
),
};
self.arena.alloc(self.expr_unsafe(call)) self.arena.alloc(self.expr_unsafe(call))
}; };
@ -1021,11 +1042,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
let awaitee_arm = self.arm(awaitee_pat, loop_expr); let awaitee_arm = self.arm(awaitee_pat, loop_expr);
// `match ::std::future::IntoFuture::into_future(<expr>) { ... }` // `match ::std::future::IntoFuture::into_future(<expr>) { ... }`
let into_future_expr = self.expr_call_lang_item_fn( let into_future_expr = match await_kind {
span, FutureKind::Future => self.expr_call_lang_item_fn(
hir::LangItem::IntoFutureIntoFuture, span,
arena_vec![self; expr], hir::LangItem::IntoFutureIntoFuture,
); arena_vec![self; *expr],
),
// Not needed for `for await` because we expect to have already called
// `IntoAsyncIterator::into_async_iter` on it.
FutureKind::AsyncIterator => expr,
};
// match <into_future_expr> { // match <into_future_expr> {
// mut __awaitee => loop { .. } // mut __awaitee => loop { .. }
@ -1673,7 +1699,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
head: &Expr, head: &Expr,
body: &Block, body: &Block,
opt_label: Option<Label>, opt_label: Option<Label>,
_loop_kind: ForLoopKind, loop_kind: ForLoopKind,
) -> hir::Expr<'hir> { ) -> hir::Expr<'hir> {
let head = self.lower_expr_mut(head); let head = self.lower_expr_mut(head);
let pat = self.lower_pat(pat); let pat = self.lower_pat(pat);
@ -1702,17 +1728,41 @@ impl<'hir> LoweringContext<'_, 'hir> {
let (iter_pat, iter_pat_nid) = let (iter_pat, iter_pat_nid) =
self.pat_ident_binding_mode(head_span, iter, hir::BindingAnnotation::MUT); self.pat_ident_binding_mode(head_span, iter, hir::BindingAnnotation::MUT);
// `match Iterator::next(&mut iter) { ... }`
let match_expr = { let match_expr = {
let iter = self.expr_ident(head_span, iter, iter_pat_nid); let iter = self.expr_ident(head_span, iter, iter_pat_nid);
let ref_mut_iter = self.expr_mut_addr_of(head_span, iter); let next_expr = match loop_kind {
let next_expr = self.expr_call_lang_item_fn( ForLoopKind::For => {
head_span, // `Iterator::next(&mut iter)`
hir::LangItem::IteratorNext, let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
arena_vec![self; ref_mut_iter], self.expr_call_lang_item_fn(
); head_span,
hir::LangItem::IteratorNext,
arena_vec![self; ref_mut_iter],
)
}
ForLoopKind::ForAwait => {
// we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this
// to make_lowered_await with `FutureKind::AsyncIterator` which will generator
// calls to `poll_next`. In user code, this would probably be a call to
// `Pin::as_mut` but here it's easy enough to do `new_unchecked`.
// `&mut iter`
let iter = self.expr_mut_addr_of(head_span, iter);
// `Pin::new_unchecked(...)`
let iter = self.arena.alloc(self.expr_call_lang_item_fn_mut(
head_span,
hir::LangItem::PinNewUnchecked,
arena_vec![self; iter],
));
// `unsafe { ... }`
let iter = self.arena.alloc(self.expr_unsafe(iter));
let kind = self.make_lowered_await(head_span, iter, FutureKind::AsyncIterator);
self.arena.alloc(hir::Expr { hir_id: self.next_id(), kind, span: head_span })
}
};
let arms = arena_vec![self; none_arm, some_arm]; let arms = arena_vec![self; none_arm, some_arm];
// `match $next_expr { ... }`
self.expr_match(head_span, next_expr, arms, hir::MatchSource::ForLoopDesugar) self.expr_match(head_span, next_expr, arms, hir::MatchSource::ForLoopDesugar)
}; };
let match_stmt = self.stmt_expr(for_span, match_expr); let match_stmt = self.stmt_expr(for_span, match_expr);
@ -1732,13 +1782,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
// `mut iter => { ... }` // `mut iter => { ... }`
let iter_arm = self.arm(iter_pat, loop_expr); let iter_arm = self.arm(iter_pat, loop_expr);
// `match ::std::iter::IntoIterator::into_iter(<head>) { ... }` let into_iter_expr = match loop_kind {
let into_iter_expr = { ForLoopKind::For => {
self.expr_call_lang_item_fn( // `::std::iter::IntoIterator::into_iter(<head>)`
head_span, self.expr_call_lang_item_fn(
hir::LangItem::IntoIterIntoIter, head_span,
arena_vec![self; head], hir::LangItem::IntoIterIntoIter,
) arena_vec![self; head],
)
}
ForLoopKind::ForAwait => self.arena.alloc(head),
}; };
let match_expr = self.arena.alloc(self.expr_match( let match_expr = self.arena.alloc(self.expr_match(
@ -2141,3 +2194,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
} }
} }
} }
/// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind
/// of future we are awaiting.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum FutureKind {
/// We are awaiting a normal future
Future,
/// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of
/// a `for await` loop)
AsyncIterator,
}

View File

@ -130,6 +130,7 @@ struct LoweringContext<'a, 'hir> {
allow_try_trait: Lrc<[Symbol]>, allow_try_trait: Lrc<[Symbol]>,
allow_gen_future: Lrc<[Symbol]>, allow_gen_future: Lrc<[Symbol]>,
allow_async_iterator: Lrc<[Symbol]>, allow_async_iterator: Lrc<[Symbol]>,
allow_for_await: Lrc<[Symbol]>,
/// Mapping from generics `def_id`s to TAIT generics `def_id`s. /// Mapping from generics `def_id`s to TAIT generics `def_id`s.
/// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic /// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic
@ -174,6 +175,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
} else { } else {
[sym::gen_future].into() [sym::gen_future].into()
}, },
allow_for_await: [sym::async_iterator].into(),
// FIXME(gen_blocks): how does `closure_track_caller`/`async_fn_track_caller` // FIXME(gen_blocks): how does `closure_track_caller`/`async_fn_track_caller`
// interact with `gen`/`async gen` blocks // interact with `gen`/`async gen` blocks
allow_async_iterator: [sym::gen_future, sym::async_iterator].into(), allow_async_iterator: [sym::gen_future, sym::async_iterator].into(),

View File

@ -303,7 +303,7 @@ impl<'cx, 'a> Context<'cx, 'a> {
| ExprKind::Continue(_) | ExprKind::Continue(_)
| ExprKind::Err | ExprKind::Err
| ExprKind::Field(_, _) | ExprKind::Field(_, _)
| ExprKind::ForLoop {..} | ExprKind::ForLoop { .. }
| ExprKind::FormatArgs(_) | ExprKind::FormatArgs(_)
| ExprKind::IncludedBytes(..) | ExprKind::IncludedBytes(..)
| ExprKind::InlineAsm(_) | ExprKind::InlineAsm(_)

View File

@ -358,7 +358,7 @@ declare_features! (
/// Allows `#[track_caller]` on async functions. /// Allows `#[track_caller]` on async functions.
(unstable, async_fn_track_caller, "1.73.0", Some(110011)), (unstable, async_fn_track_caller, "1.73.0", Some(110011)),
/// Allows `for await` loops. /// Allows `for await` loops.
(unstable, async_for_loop, "CURRENT_RUSTC_VERSION", None), (unstable, async_for_loop, "CURRENT_RUSTC_VERSION", Some(118898)),
/// Allows builtin # foo() syntax /// Allows builtin # foo() syntax
(unstable, builtin_syntax, "1.71.0", Some(110680)), (unstable, builtin_syntax, "1.71.0", Some(110680)),
/// Treat `extern "C"` function as nounwind. /// Treat `extern "C"` function as nounwind.

View File

@ -307,6 +307,8 @@ language_item_table! {
Context, sym::Context, context, Target::Struct, GenericRequirement::None; Context, sym::Context, context, Target::Struct, GenericRequirement::None;
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None; FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
AsyncIteratorPollNext, sym::async_iterator_poll_next, async_iterator_poll_next, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::Exact(0);
Option, sym::Option, option_type, Target::Enum, GenericRequirement::None; Option, sym::Option, option_type, Target::Enum, GenericRequirement::None;
OptionSome, sym::Some, option_some_variant, Target::Variant, GenericRequirement::None; OptionSome, sym::Some, option_some_variant, Target::Variant, GenericRequirement::None;
OptionNone, sym::None, option_none_variant, Target::Variant, GenericRequirement::None; OptionNone, sym::None, option_none_variant, Target::Variant, GenericRequirement::None;

View File

@ -428,6 +428,7 @@ symbols! {
async_fn_track_caller, async_fn_track_caller,
async_for_loop, async_for_loop,
async_iterator, async_iterator,
async_iterator_poll_next,
atomic, atomic,
atomic_mod, atomic_mod,
atomics, atomics,
@ -894,6 +895,7 @@ symbols! {
instruction_set, instruction_set,
integer_: "integer", // underscore to avoid clashing with the function `sym::integer` below integer_: "integer", // underscore to avoid clashing with the function `sym::integer` below
integral, integral,
into_async_iter_into_iter,
into_future, into_future,
into_iter, into_iter,
intra_doc_pointers, intra_doc_pointers,

View File

@ -47,6 +47,7 @@ pub trait AsyncIterator {
/// Rust's usual rules apply: calls must never cause undefined behavior /// Rust's usual rules apply: calls must never cause undefined behavior
/// (memory corruption, incorrect use of `unsafe` functions, or the like), /// (memory corruption, incorrect use of `unsafe` functions, or the like),
/// regardless of the async iterator's state. /// regardless of the async iterator's state.
#[cfg_attr(not(bootstrap), lang = "async_iterator_poll_next")]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
/// Returns the bounds on the remaining length of the async iterator. /// Returns the bounds on the remaining length of the async iterator.

View File

@ -0,0 +1,24 @@
// run-pass
// edition: 2021
#![feature(async_iterator, async_iter_from_iter, const_waker, async_for_loop, noop_waker)]
use std::future::Future;
// make sure a simple for await loop works
async fn real_main() {
let iter = core::async_iter::from_iter(0..3);
let mut count = 0;
for await i in iter {
assert_eq!(i, count);
count += 1;
}
assert_eq!(count, 3);
}
fn main() {
let future = real_main();
let waker = std::task::Waker::noop();
let mut cx = &mut core::task::Context::from_waker(&waker);
let mut future = core::pin::pin!(future);
while let core::task::Poll::Pending = future.as_mut().poll(&mut cx) {}
}