From 427896dd7e39f1aaf3e3cbc15e5ddf77d45a6aec Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Wed, 24 Jan 2024 23:38:33 +0000 Subject: [PATCH] Construct body for by-move coroutine closure output --- .../src/type_check/input_output.rs | 1 + .../src/interpret/terminator.rs | 1 + compiler/rustc_hir_typeck/src/callee.rs | 1 + compiler/rustc_hir_typeck/src/closure.rs | 11 ++ compiler/rustc_hir_typeck/src/upvar.rs | 10 ++ compiler/rustc_middle/src/mir/mod.rs | 5 + compiler/rustc_middle/src/mir/mono.rs | 1 + compiler/rustc_middle/src/mir/visit.rs | 1 + compiler/rustc_middle/src/ty/instance.rs | 21 +++- compiler/rustc_middle/src/ty/mod.rs | 1 + compiler/rustc_middle/src/ty/sty.rs | 47 ++++++-- compiler/rustc_mir_transform/src/coroutine.rs | 3 + .../src/coroutine/by_move_body.rs | 108 ++++++++++++++++++ compiler/rustc_mir_transform/src/inline.rs | 1 + .../rustc_mir_transform/src/inline/cycle.rs | 1 + compiler/rustc_mir_transform/src/lib.rs | 4 + .../rustc_mir_transform/src/pass_manager.rs | 6 + compiler/rustc_mir_transform/src/shim.rs | 12 ++ compiler/rustc_monomorphize/src/collector.rs | 3 +- .../rustc_monomorphize/src/partitioning.rs | 4 +- .../rustc_smir/src/rustc_smir/convert/ty.rs | 1 + .../src/solve/assembly/structural_traits.rs | 1 + .../src/traits/project.rs | 2 + ...await.b-{closure#0}.coroutine_resume.0.mir | 2 + 24 files changed, 233 insertions(+), 15 deletions(-) create mode 100644 compiler/rustc_mir_transform/src/coroutine/by_move_body.rs diff --git a/compiler/rustc_borrowck/src/type_check/input_output.rs b/compiler/rustc_borrowck/src/type_check/input_output.rs index a3e5088ee09..ace9c5ae71d 100644 --- a/compiler/rustc_borrowck/src/type_check/input_output.rs +++ b/compiler/rustc_borrowck/src/type_check/input_output.rs @@ -85,6 +85,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { self.tcx(), ty::CoroutineArgsParts { parent_args: args.parent_args(), + kind_ty: Ty::from_closure_kind(self.tcx(), args.kind()), resume_ty: next_ty_var(), yield_ty: next_ty_var(), witness: next_ty_var(), diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs index 4c8f68b25b5..b8d6836da14 100644 --- a/compiler/rustc_const_eval/src/interpret/terminator.rs +++ b/compiler/rustc_const_eval/src/interpret/terminator.rs @@ -546,6 +546,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_hir_typeck/src/callee.rs b/compiler/rustc_hir_typeck/src/callee.rs index 1858b2770cd..730a475f630 100644 --- a/compiler/rustc_hir_typeck/src/callee.rs +++ b/compiler/rustc_hir_typeck/src/callee.rs @@ -183,6 +183,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { coroutine_closure_sig.to_coroutine( self.tcx, closure_args.parent_args(), + closure_args.kind_ty(), self.tcx.coroutine_for_closure(def_id), tupled_upvars_ty, ), diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index 1d024efdd49..014293c1f83 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -175,10 +175,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { interior, )); + let kind_ty = match kind { + hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure) => self + .next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::ClosureSynthetic, + span: expr_span, + }), + _ => tcx.types.unit, + }; + let coroutine_args = ty::CoroutineArgs::new( tcx, ty::CoroutineArgsParts { parent_args, + kind_ty, resume_ty, yield_ty, return_ty: liberated_sig.output(), @@ -256,6 +266,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { sig.to_coroutine( tcx, parent_args, + closure_kind_ty, tcx.coroutine_for_closure(expr_def_id), coroutine_upvars_ty, ) diff --git a/compiler/rustc_hir_typeck/src/upvar.rs b/compiler/rustc_hir_typeck/src/upvar.rs index b087d6d9e57..d4e072976fa 100644 --- a/compiler/rustc_hir_typeck/src/upvar.rs +++ b/compiler/rustc_hir_typeck/src/upvar.rs @@ -393,6 +393,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { args.as_coroutine_closure().coroutine_captures_by_ref_ty(), coroutine_captures_by_ref_ty, ); + + let ty::Coroutine(_, args) = *self.typeck_results.borrow().expr_ty(body.value).kind() + else { + bug!(); + }; + self.demand_eqtype( + span, + args.as_coroutine().kind_ty(), + Ty::from_closure_kind(self.tcx, closure_kind), + ); } self.log_closure_min_capture_info(closure_def_id, span); diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index c9e69253701..d88e9261e5a 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -262,6 +262,10 @@ pub struct CoroutineInfo<'tcx> { /// Coroutine drop glue. This field is populated after the state transform pass. pub coroutine_drop: Option>, + /// The body of the coroutine, modified to take its upvars by move. + /// TODO: + pub by_move_body: Option>, + /// The layout of a coroutine. This field is populated after the state transform pass. pub coroutine_layout: Option>, @@ -281,6 +285,7 @@ impl<'tcx> CoroutineInfo<'tcx> { coroutine_kind, yield_ty: Some(yield_ty), resume_ty: Some(resume_ty), + by_move_body: None, coroutine_drop: None, coroutine_layout: None, } diff --git a/compiler/rustc_middle/src/mir/mono.rs b/compiler/rustc_middle/src/mir/mono.rs index 4a29171d8bf..e6d1535fdf2 100644 --- a/compiler/rustc_middle/src/mir/mono.rs +++ b/compiler/rustc_middle/src/mir/mono.rs @@ -403,6 +403,7 @@ impl<'tcx> CodegenUnit<'tcx> { | InstanceDef::Virtual(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs index 6bc58adea0f..ce1859d6ada 100644 --- a/compiler/rustc_middle/src/mir/visit.rs +++ b/compiler/rustc_middle/src/mir/visit.rs @@ -346,6 +346,7 @@ macro_rules! make_mir_visitor { ty::InstanceDef::ThreadLocalShim(_def_id) | ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } | ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } | + ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: _def_id } | ty::InstanceDef::DropGlue(_def_id, None) => {} ty::InstanceDef::FnPtrShim(_def_id, ty) | diff --git a/compiler/rustc_middle/src/ty/instance.rs b/compiler/rustc_middle/src/ty/instance.rs index 41ae136851e..44bf3c32b48 100644 --- a/compiler/rustc_middle/src/ty/instance.rs +++ b/compiler/rustc_middle/src/ty/instance.rs @@ -101,6 +101,9 @@ pub enum InstanceDef<'tcx> { target_kind: ty::ClosureKind, }, + /// TODO: + CoroutineByMoveShim { coroutine_def_id: DefId }, + /// Compiler-generated accessor for thread locals which returns a reference to the thread local /// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking /// native support. @@ -186,6 +189,7 @@ impl<'tcx> InstanceDef<'tcx> { coroutine_closure_def_id: def_id, target_kind: _, } + | ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: def_id } | InstanceDef::DropGlue(def_id, _) | InstanceDef::CloneShim(def_id, _) | InstanceDef::FnPtrAddrShim(def_id, _) => def_id, @@ -206,6 +210,7 @@ impl<'tcx> InstanceDef<'tcx> { | InstanceDef::Intrinsic(..) | InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::FnPtrAddrShim(..) => None, @@ -302,6 +307,7 @@ impl<'tcx> InstanceDef<'tcx> { | InstanceDef::DropGlue(_, Some(_)) => false, InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::Item(_) | InstanceDef::Intrinsic(..) @@ -340,6 +346,7 @@ fn fmt_instance( InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"), InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"), InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"), + InstanceDef::CoroutineByMoveShim { .. } => write!(f, " - shim"), InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"), InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"), InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"), @@ -631,7 +638,19 @@ impl<'tcx> Instance<'tcx> { }; if tcx.lang_items().get(coroutine_callable_item) == Some(trait_item_id) { - Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args }) + let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind() + else { + bug!() + }; + + if args.as_coroutine().kind_ty() == id_args.as_coroutine().kind_ty() { + Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args }) + } else { + Some(Instance { + def: ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id }, + args, + }) + } } else { // All other methods should be defaulted methods of the built-in trait. // This is important for `Iterator`'s combinators, but also useful for diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 05875a9798b..9ceb3ec3f61 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -1681,6 +1681,7 @@ impl<'tcx> TyCtxt<'tcx> { | ty::InstanceDef::Virtual(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) | ty::InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index e2a2e24f06d..8918a3735d6 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -399,6 +399,7 @@ impl<'tcx> CoroutineClosureSignature<'tcx> { self, tcx: TyCtxt<'tcx>, parent_args: &'tcx [GenericArg<'tcx>], + kind_ty: Ty<'tcx>, coroutine_def_id: DefId, tupled_upvars_ty: Ty<'tcx>, ) -> Ty<'tcx> { @@ -406,6 +407,7 @@ impl<'tcx> CoroutineClosureSignature<'tcx> { tcx, ty::CoroutineArgsParts { parent_args, + kind_ty, resume_ty: self.resume_ty, yield_ty: self.yield_ty, return_ty: self.return_ty, @@ -436,7 +438,13 @@ impl<'tcx> CoroutineClosureSignature<'tcx> { env_region, ); - self.to_coroutine(tcx, parent_args, coroutine_def_id, tupled_upvars_ty) + self.to_coroutine( + tcx, + parent_args, + Ty::from_closure_kind(tcx, closure_kind), + coroutine_def_id, + tupled_upvars_ty, + ) } /// Given a closure kind, compute the tupled upvars that the given coroutine would return. @@ -488,6 +496,8 @@ pub struct CoroutineArgs<'tcx> { pub struct CoroutineArgsParts<'tcx> { /// This is the args of the typeck root. pub parent_args: &'tcx [GenericArg<'tcx>], + // TODO: why + pub kind_ty: Ty<'tcx>, pub resume_ty: Ty<'tcx>, pub yield_ty: Ty<'tcx>, pub return_ty: Ty<'tcx>, @@ -506,6 +516,7 @@ impl<'tcx> CoroutineArgs<'tcx> { pub fn new(tcx: TyCtxt<'tcx>, parts: CoroutineArgsParts<'tcx>) -> CoroutineArgs<'tcx> { CoroutineArgs { args: tcx.mk_args_from_iter(parts.parent_args.iter().copied().chain([ + parts.kind_ty.into(), parts.resume_ty.into(), parts.yield_ty.into(), parts.return_ty.into(), @@ -519,16 +530,23 @@ impl<'tcx> CoroutineArgs<'tcx> { /// The ordering assumed here must match that used by `CoroutineArgs::new` above. fn split(self) -> CoroutineArgsParts<'tcx> { match self.args[..] { - [ref parent_args @ .., resume_ty, yield_ty, return_ty, witness, tupled_upvars_ty] => { - CoroutineArgsParts { - parent_args, - resume_ty: resume_ty.expect_ty(), - yield_ty: yield_ty.expect_ty(), - return_ty: return_ty.expect_ty(), - witness: witness.expect_ty(), - tupled_upvars_ty: tupled_upvars_ty.expect_ty(), - } - } + [ + ref parent_args @ .., + kind_ty, + resume_ty, + yield_ty, + return_ty, + witness, + tupled_upvars_ty, + ] => CoroutineArgsParts { + parent_args, + kind_ty: kind_ty.expect_ty(), + resume_ty: resume_ty.expect_ty(), + yield_ty: yield_ty.expect_ty(), + return_ty: return_ty.expect_ty(), + witness: witness.expect_ty(), + tupled_upvars_ty: tupled_upvars_ty.expect_ty(), + }, _ => bug!("coroutine args missing synthetics"), } } @@ -538,6 +556,11 @@ impl<'tcx> CoroutineArgs<'tcx> { self.split().parent_args } + // TODO: + pub fn kind_ty(self) -> Ty<'tcx> { + self.split().kind_ty + } + /// This describes the types that can be contained in a coroutine. /// It will be a type variable initially and unified in the last stages of typeck of a body. /// It contains a tuple of all the types that could end up on a coroutine frame. @@ -1628,7 +1651,7 @@ impl<'tcx> Ty<'tcx> { ) -> Ty<'tcx> { debug_assert_eq!( coroutine_args.len(), - tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 5, + tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 6, "coroutine constructed with incorrect number of substitutions" ); Ty::new(tcx, Coroutine(def_id, coroutine_args)) diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index bde879f6067..297b2fa143d 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -50,6 +50,9 @@ //! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing. //! Otherwise it drops all the values in scope at the last suspension point. +mod by_move_body; +pub use by_move_body::ByMoveBody; + use crate::abort_unwinding_calls; use crate::deref_separator::deref_finder; use crate::errors; diff --git a/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs new file mode 100644 index 00000000000..4e3e70bdafe --- /dev/null +++ b/compiler/rustc_mir_transform/src/coroutine/by_move_body.rs @@ -0,0 +1,108 @@ +use rustc_data_structures::fx::FxIndexSet; +use rustc_hir as hir; +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::{self, MirPass}; +use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt}; +use rustc_target::abi::FieldIdx; + +pub struct ByMoveBody; + +impl<'tcx> MirPass<'tcx> for ByMoveBody { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { + let Some(coroutine_def_id) = body.source.def_id().as_local() else { + return; + }; + let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) = + tcx.coroutine_kind(coroutine_def_id) + else { + return; + }; + let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() }; + if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce { + return; + } + + let mut by_ref_fields = FxIndexSet::default(); + let by_move_upvars = Ty::new_tup_from_iter( + tcx, + tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| { + if capture.is_by_ref() { + by_ref_fields.insert(FieldIdx::from_usize(idx)); + } + capture.place.ty() + }), + ); + let by_move_coroutine_ty = Ty::new_coroutine( + tcx, + coroutine_def_id.to_def_id(), + ty::CoroutineArgs::new( + tcx, + ty::CoroutineArgsParts { + parent_args: args.as_coroutine().parent_args(), + kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce), + resume_ty: args.as_coroutine().resume_ty(), + yield_ty: args.as_coroutine().yield_ty(), + return_ty: args.as_coroutine().return_ty(), + witness: args.as_coroutine().witness(), + tupled_upvars_ty: by_move_upvars, + }, + ) + .args, + ); + + let mut by_move_body = body.clone(); + MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body); + by_move_body.source = mir::MirSource { + instance: InstanceDef::CoroutineByMoveShim { + coroutine_def_id: coroutine_def_id.to_def_id(), + }, + promoted: None, + }; + + body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body); + } +} + +struct MakeByMoveBody<'tcx> { + tcx: TyCtxt<'tcx>, + by_ref_fields: FxIndexSet, + by_move_coroutine_ty: Ty<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place( + &mut self, + place: &mut mir::Place<'tcx>, + context: mir::visit::PlaceContext, + location: mir::Location, + ) { + if place.local == ty::CAPTURE_STRUCT_LOCAL + && !place.projection.is_empty() + && let mir::ProjectionElem::Field(idx, ty) = place.projection[0] + && self.by_ref_fields.contains(&idx) + { + let (begin, end) = place.projection[1..].split_first().unwrap(); + assert_eq!(*begin, mir::ProjectionElem::Deref); + *place = mir::Place { + local: place.local, + projection: self.tcx.mk_place_elems_from_iter( + [mir::ProjectionElem::Field(idx, ty.builtin_deref(true).unwrap().ty)] + .into_iter() + .chain(end.iter().copied()), + ), + }; + } + self.super_place(place, context, location); + } + + fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) { + if local == ty::CAPTURE_STRUCT_LOCAL { + local_decl.ty = self.by_move_coroutine_ty; + } + } +} diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 7c731b070a7..24bc84a235c 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -318,6 +318,7 @@ impl<'tcx> Inliner<'tcx> { | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::ThreadLocalShim(..) diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index 3f3dc9145b6..77ff780393e 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -88,6 +88,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>( | InstanceDef::FnPtrShim(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::ThreadLocalShim { .. } | InstanceDef::CloneShim(..) => {} diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 69f93fa3a0e..031515ea958 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -307,6 +307,10 @@ fn mir_const(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal> { &Lint(check_packed_ref::CheckPackedRef), &Lint(check_const_item_mutation::CheckConstItemMutation), &Lint(function_item_references::FunctionItemReferences), + // If this is an async closure's output coroutine, generate + // by-move and by-mut bodies if needed. We do this first so + // they can be optimized in lockstep with their parent bodies. + &coroutine::ByMoveBody, // What we need to do constant evaluation. &simplify::SimplifyCfg::Initial, &rustc_peek::SanityCheck, // Just a lint diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs index c1ef2b9f887..c7e770904fb 100644 --- a/compiler/rustc_mir_transform/src/pass_manager.rs +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -189,6 +189,12 @@ fn run_passes_inner<'tcx>( body.pass_count = 1; } + + if let Some(coroutine) = body.coroutine.as_mut() + && let Some(by_move_body) = coroutine.by_move_body.as_mut() + { + run_passes_inner(tcx, by_move_body, passes, phase_change, validate_each); + } } pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: String) { diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 29b83f58ef5..668ccdd8735 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -81,6 +81,18 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' } }, + ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id } => { + return tcx + .optimized_mir(coroutine_def_id) + .coroutine + .as_ref() + .unwrap() + .by_move_body + .as_ref() + .unwrap() + .clone(); + } + ty::InstanceDef::DropGlue(def_id, ty) => { // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end // of this function. Is this intentional? diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 698fd634114..cf3c8e1fdd3 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -983,7 +983,8 @@ fn visit_instance_use<'tcx>( | ty::InstanceDef::VTableShim(..) | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::ClosureOnceShim { .. } - | InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | ty::InstanceDef::Item(..) | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 46bd33c89e7..22b35c4344b 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -620,7 +620,8 @@ fn characteristic_def_id_of_mono_item<'tcx>( | ty::InstanceDef::ReifyShim(..) | ty::InstanceDef::FnPtrShim(..) | ty::InstanceDef::ClosureOnceShim { .. } - | InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | ty::InstanceDef::Intrinsic(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::Virtual(..) @@ -785,6 +786,7 @@ fn mono_item_visibility<'tcx>( | InstanceDef::Intrinsic(..) | InstanceDef::ClosureOnceShim { .. } | InstanceDef::ConstructCoroutineInClosureShim { .. } + | InstanceDef::CoroutineByMoveShim { .. } | InstanceDef::DropGlue(..) | InstanceDef::CloneShim(..) | InstanceDef::FnPtrAddrShim(..) => return Visibility::Hidden, diff --git a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs index e0e9815cf40..3c1858e920b 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs @@ -800,6 +800,7 @@ impl<'tcx> Stable<'tcx> for ty::Instance<'tcx> { | ty::InstanceDef::FnPtrAddrShim(..) | ty::InstanceDef::ClosureOnceShim { .. } | ty::InstanceDef::ConstructCoroutineInClosureShim { .. } + | ty::InstanceDef::CoroutineByMoveShim { .. } | ty::InstanceDef::ThreadLocalShim(..) | ty::InstanceDef::DropGlue(..) | ty::InstanceDef::CloneShim(..) diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs index c35134c78eb..0699026117d 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs @@ -366,6 +366,7 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc let coroutine_ty = sig.to_coroutine( tcx, args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), tcx.coroutine_for_closure(def_id), tupled_upvars_ty, ); diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 648f14beaa7..db1e89ae72f 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -2505,6 +2505,7 @@ fn confirm_async_closure_candidate<'cx, 'tcx>( let coroutine_ty = sig.to_coroutine( tcx, args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), tcx.coroutine_for_closure(def_id), tupled_upvars_ty, ); @@ -2533,6 +2534,7 @@ fn confirm_async_closure_candidate<'cx, 'tcx>( let coroutine_ty = sig.to_coroutine( tcx, args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), tcx.coroutine_for_closure(def_id), tupled_upvars_ty, ); diff --git a/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir b/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir index 3c0d4008c90..9c8cf8763fd 100644 --- a/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir +++ b/tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir @@ -5,6 +5,7 @@ ty: Coroutine( DefId(0:4 ~ async_await[ccf8]::a::{closure#0}), [ + (), std::future::ResumeTy, (), (), @@ -22,6 +23,7 @@ ty: Coroutine( DefId(0:4 ~ async_await[ccf8]::a::{closure#0}), [ + (), std::future::ResumeTy, (), (),