From b7bc8d5cb7685bd8e35d7b1c9d3011b043abf775 Mon Sep 17 00:00:00 2001 From: bjorn3 <17426603+bjorn3@users.noreply.github.com> Date: Thu, 23 Nov 2023 20:02:45 +0000 Subject: [PATCH] Fix fn_sig_for_fn_abi and the coroutine transform for generators There were three issues previously: * The self argument was pinned, despite Iterator::next taking an unpinned mutable reference. * A resume argument was passed, despite Iterator::next not having one. * The return value was CoroutineState rather than Option While these things just so happened to work with the LLVM backend, cg_clif does much stricter checks when trying to assign a value to a place. In addition it can't handle the mismatch between the amount of arguments specified by the FnAbi and the FnSig. --- .../build_system/tests.rs | 9 ++++ compiler/rustc_codegen_cranelift/config.txt | 1 + .../example/gen_block_iterate.rs | 36 +++++++++++++ compiler/rustc_codegen_cranelift/rustfmt.toml | 5 +- compiler/rustc_mir_transform/src/coroutine.rs | 32 ++++++++++- compiler/rustc_ty_utils/src/abi.rs | 54 ++++++++++++++++--- rustfmt.toml | 1 + 7 files changed, 129 insertions(+), 9 deletions(-) create mode 100644 compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs diff --git a/compiler/rustc_codegen_cranelift/build_system/tests.rs b/compiler/rustc_codegen_cranelift/build_system/tests.rs index ff71a567ed3..aa50dbfdf35 100644 --- a/compiler/rustc_codegen_cranelift/build_system/tests.rs +++ b/compiler/rustc_codegen_cranelift/build_system/tests.rs @@ -100,6 +100,15 @@ const BASE_SYSROOT_SUITE: &[TestCase] = &[ TestCase::build_bin_and_run("aot.issue-72793", "example/issue-72793.rs", &[]), TestCase::build_bin("aot.issue-59326", "example/issue-59326.rs"), TestCase::build_bin_and_run("aot.neon", "example/neon.rs", &[]), + TestCase::custom("aot.gen_block_iterate", &|runner| { + runner.run_rustc([ + "example/gen_block_iterate.rs", + "--edition", + "2024", + "-Zunstable-options", + ]); + runner.run_out_command("gen_block_iterate", &[]); + }), ]; pub(crate) static RAND_REPO: GitRepo = GitRepo::github( diff --git a/compiler/rustc_codegen_cranelift/config.txt b/compiler/rustc_codegen_cranelift/config.txt index 2ccdc7d7874..3cf295c003e 100644 --- a/compiler/rustc_codegen_cranelift/config.txt +++ b/compiler/rustc_codegen_cranelift/config.txt @@ -43,6 +43,7 @@ aot.mod_bench aot.issue-72793 aot.issue-59326 aot.neon +aot.gen_block_iterate testsuite.extended_sysroot test.rust-random/rand diff --git a/compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs b/compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs new file mode 100644 index 00000000000..14bd23e77ea --- /dev/null +++ b/compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs @@ -0,0 +1,36 @@ +// Copied from https://github.com/rust-lang/rust/blob/46455dc65069387f2dc46612f13fd45452ab301a/tests/ui/coroutine/gen_block_iterate.rs +// revisions: next old +//compile-flags: --edition 2024 -Zunstable-options +//[next] compile-flags: -Ztrait-solver=next +// run-pass +#![feature(gen_blocks)] + +fn foo() -> impl Iterator { + gen { yield 42; for x in 3..6 { yield x } } +} + +fn moved() -> impl Iterator { + let mut x = "foo".to_string(); + gen move { + yield 42; + if x == "foo" { return } + x.clear(); + for x in 3..6 { yield x } + } +} + +fn main() { + let mut iter = foo(); + assert_eq!(iter.next(), Some(42)); + assert_eq!(iter.next(), Some(3)); + assert_eq!(iter.next(), Some(4)); + assert_eq!(iter.next(), Some(5)); + assert_eq!(iter.next(), None); + // `gen` blocks are fused + assert_eq!(iter.next(), None); + + let mut iter = moved(); + assert_eq!(iter.next(), Some(42)); + assert_eq!(iter.next(), None); + +} diff --git a/compiler/rustc_codegen_cranelift/rustfmt.toml b/compiler/rustc_codegen_cranelift/rustfmt.toml index ebeca8662a5..0f884187add 100644 --- a/compiler/rustc_codegen_cranelift/rustfmt.toml +++ b/compiler/rustc_codegen_cranelift/rustfmt.toml @@ -1,4 +1,7 @@ -ignore = ["y.rs"] +ignore = [ + "y.rs", + "example/gen_block_iterate.rs", # uses edition 2024 +] # Matches rustfmt.toml of rustc version = "Two" diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index aa4d8ddad56..42540911785 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -617,6 +617,22 @@ fn replace_resume_ty_local<'tcx>( } } +/// Transforms the `body` of the coroutine applying the following transform: +/// +/// - Remove the `resume` argument. +/// +/// Ideally the async lowering would not add the `resume` argument. +/// +/// The async lowering step and the type / lifetime inference / checking are +/// still using the `resume` argument for the time being. After this transform, +/// the coroutine body doesn't have the `resume` argument. +fn transform_gen_context<'tcx>(_tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // This leaves the local representing the `resume` argument in place, + // but turns it into a regular local variable. This is cheaper than + // adjusting all local references in the body after removing it. + body.arg_count = 1; +} + struct LivenessInfo { /// Which locals are live across any suspension point. saved_locals: CoroutineSavedLocals, @@ -1337,7 +1353,15 @@ fn create_coroutine_resume_function<'tcx>( insert_switch(body, cases, &transform, TerminatorKind::Unreachable); make_coroutine_state_argument_indirect(tcx, body); - make_coroutine_state_argument_pinned(tcx, body); + + match coroutine_kind { + // Iterator::next doesn't accept a pinned argument, + // unlike for all other coroutine kinds. + CoroutineKind::Gen(_) => {} + _ => { + make_coroutine_state_argument_pinned(tcx, body); + } + } // Make sure we remove dead blocks to remove // unrelated code from the drop part of the function @@ -1504,6 +1528,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { }; let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_))); + let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_))); let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() { CoroutineKind::Async(_) => { // Compute Poll @@ -1609,6 +1634,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform { body.arg_count = 2; // self, resume arg body.spread_arg = None; + // Remove the context argument within generator bodies. + if is_gen_kind { + transform_gen_context(tcx, body); + } + // The original arguments to the function are no longer arguments, mark them as such. // Otherwise they'll conflict with our new arguments, which although they don't have // argument_index set, will get emitted as unnamed arguments. diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs index 737acfbc600..8ea78b9b532 100644 --- a/compiler/rustc_ty_utils/src/abi.rs +++ b/compiler/rustc_ty_utils/src/abi.rs @@ -112,7 +112,13 @@ fn fn_sig_for_fn_abi<'tcx>( let pin_did = tcx.require_lang_item(LangItem::Pin, None); let pin_adt_ref = tcx.adt_def(pin_did); let pin_args = tcx.mk_args(&[env_ty.into()]); - let env_ty = Ty::new_adt(tcx, pin_adt_ref, pin_args); + let env_ty = if tcx.coroutine_is_gen(did) { + // Iterator::next doesn't accept a pinned argument, + // unlike for all other coroutine kinds. + env_ty + } else { + Ty::new_adt(tcx, pin_adt_ref, pin_args) + }; let sig = sig.skip_binder(); // The `FnSig` and the `ret_ty` here is for a coroutines main @@ -121,6 +127,8 @@ fn fn_sig_for_fn_abi<'tcx>( // function in case this is a special coroutine backing an async construct. let (resume_ty, ret_ty) = if tcx.coroutine_is_async(did) { // The signature should be `Future::poll(_, &mut Context<'_>) -> Poll` + assert_eq!(sig.yield_ty, tcx.types.unit); + let poll_did = tcx.require_lang_item(LangItem::Poll, None); let poll_adt_ref = tcx.adt_def(poll_did); let poll_args = tcx.mk_args(&[sig.return_ty.into()]); @@ -140,7 +148,30 @@ fn fn_sig_for_fn_abi<'tcx>( } let context_mut_ref = Ty::new_task_context(tcx); - (context_mut_ref, ret_ty) + (Some(context_mut_ref), ret_ty) + } else if tcx.coroutine_is_gen(did) { + // The signature should be `Iterator::next(_) -> Option` + let option_did = tcx.require_lang_item(LangItem::Option, None); + let option_adt_ref = tcx.adt_def(option_did); + let option_args = tcx.mk_args(&[sig.yield_ty.into()]); + let ret_ty = Ty::new_adt(tcx, option_adt_ref, option_args); + + assert_eq!(sig.return_ty, tcx.types.unit); + + // We have to replace the `ResumeTy` that is used for type and borrow checking + // with `()` which is used in codegen. + #[cfg(debug_assertions)] + { + if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() { + let expected_adt = + tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None)); + assert_eq!(*resume_ty_adt, expected_adt); + } else { + panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty); + }; + } + + (None, ret_ty) } else { // The signature should be `Coroutine::resume(_, Resume) -> CoroutineState` let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); @@ -148,19 +179,28 @@ fn fn_sig_for_fn_abi<'tcx>( let state_args = tcx.mk_args(&[sig.yield_ty.into(), sig.return_ty.into()]); let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args); - (sig.resume_ty, ret_ty) + (Some(sig.resume_ty), ret_ty) }; - ty::Binder::bind_with_vars( + let fn_sig = if let Some(resume_ty) = resume_ty { tcx.mk_fn_sig( [env_ty, resume_ty], ret_ty, false, hir::Unsafety::Normal, rustc_target::spec::abi::Abi::Rust, - ), - bound_vars, - ) + ) + } else { + // `Iterator::next` doesn't have a `resume` argument. + tcx.mk_fn_sig( + [env_ty], + ret_ty, + false, + hir::Unsafety::Normal, + rustc_target::spec::abi::Abi::Rust, + ) + }; + ty::Binder::bind_with_vars(fn_sig, bound_vars) } _ => bug!("unexpected type {:?} in Instance::fn_sig", ty), } diff --git a/rustfmt.toml b/rustfmt.toml index 88700779e87..e292a310742 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -39,4 +39,5 @@ ignore = [ # these are ignored by a standard cargo fmt run "compiler/rustc_codegen_cranelift/y.rs", # running rustfmt breaks this file "compiler/rustc_codegen_cranelift/scripts", + "compiler/rustc_codegen_cranelift/example/gen_block_iterate.rs", # uses edition 2024 ]