From 9de6b70bb65c922c6e75e2439cb5a8cb9d30e2a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Esteban=20K=C3=BCber?= Date: Fri, 8 Mar 2024 22:59:53 +0000 Subject: [PATCH] Provide suggestion to dereference closure tail if appropriate When encoutnering a case like ```rust //@ run-rustfix use std::collections::HashMap; fn main() { let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3]; let mut counts = HashMap::new(); for num in vs { let count = counts.entry(num).or_insert(0); *count += 1; } let _ = counts.iter().max_by_key(|(_, v)| v); ``` produce the following suggestion ``` error: lifetime may not live long enough --> $DIR/return-value-lifetime-error.rs:13:47 | LL | let _ = counts.iter().max_by_key(|(_, v)| v); | ------- ^ returning this value requires that `'1` must outlive `'2` | | | | | return type of closure is &'2 &i32 | has type `&'1 (&i32, &i32)` | help: dereference the return value | LL | let _ = counts.iter().max_by_key(|(_, v)| **v); | ++ ``` Fix #50195. --- .../src/diagnostics/conflict_errors.rs | 12 +- .../src/diagnostics/region_errors.rs | 202 ++++++++++++++++++ compiler/rustc_hir_typeck/src/lib.rs | 25 ++- compiler/rustc_middle/src/query/keys.rs | 13 ++ compiler/rustc_middle/src/query/mod.rs | 3 + .../return-value-lifetime-error.fixed | 16 ++ .../closures/return-value-lifetime-error.rs | 16 ++ .../return-value-lifetime-error.stderr | 16 ++ 8 files changed, 298 insertions(+), 5 deletions(-) create mode 100644 tests/ui/closures/return-value-lifetime-error.fixed create mode 100644 tests/ui/closures/return-value-lifetime-error.rs create mode 100644 tests/ui/closures/return-value-lifetime-error.stderr diff --git a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs index 62e16d445c6..47bd24f1e14 100644 --- a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs +++ b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs @@ -1469,27 +1469,31 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> { let hir = tcx.hir(); let Some(body_id) = tcx.hir_node(self.mir_hir_id()).body_id() else { return }; struct FindUselessClone<'hir> { + tcx: TyCtxt<'hir>, + def_id: DefId, pub clones: Vec<&'hir hir::Expr<'hir>>, } impl<'hir> FindUselessClone<'hir> { - pub fn new() -> Self { - Self { clones: vec![] } + pub fn new(tcx: TyCtxt<'hir>, def_id: DefId) -> Self { + Self { tcx, def_id, clones: vec![] } } } impl<'v> Visitor<'v> for FindUselessClone<'v> { fn visit_expr(&mut self, ex: &'v hir::Expr<'v>) { - // FIXME: use `lookup_method_for_diagnostic`? if let hir::ExprKind::MethodCall(segment, _rcvr, args, _span) = ex.kind && segment.ident.name == sym::clone && args.len() == 0 + && let Some(def_id) = self.def_id.as_local() + && let Some(method) = self.tcx.lookup_method_for_diagnostic((def_id, ex.hir_id)) + && Some(self.tcx.parent(method)) == self.tcx.lang_items().clone_trait() { self.clones.push(ex); } hir::intravisit::walk_expr(self, ex); } } - let mut expr_finder = FindUselessClone::new(); + let mut expr_finder = FindUselessClone::new(tcx, self.mir_def_id().into()); let body = hir.body(body_id).value; expr_finder.visit_expr(body); diff --git a/compiler/rustc_borrowck/src/diagnostics/region_errors.rs b/compiler/rustc_borrowck/src/diagnostics/region_errors.rs index c92fccc959f..8210727c6d9 100644 --- a/compiler/rustc_borrowck/src/diagnostics/region_errors.rs +++ b/compiler/rustc_borrowck/src/diagnostics/region_errors.rs @@ -20,12 +20,17 @@ use rustc_infer::infer::{ }; use rustc_middle::hir::place::PlaceBase; use rustc_middle::mir::{ConstraintCategory, ReturnConstraint}; +use rustc_middle::traits::ObligationCause; use rustc_middle::ty::GenericArgs; use rustc_middle::ty::TypeVisitor; use rustc_middle::ty::{self, RegionVid, Ty}; use rustc_middle::ty::{Region, TyCtxt}; use rustc_span::symbol::{kw, Ident}; use rustc_span::Span; +use rustc_trait_selection::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; +use rustc_trait_selection::infer::InferCtxtExt; +use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _; +use rustc_trait_selection::traits::Obligation; use crate::borrowck_errors; use crate::session_diagnostics::{ @@ -810,6 +815,7 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> { self.add_static_impl_trait_suggestion(&mut diag, *fr, fr_name, *outlived_fr); self.suggest_adding_lifetime_params(&mut diag, *fr, *outlived_fr); self.suggest_move_on_borrowing_closure(&mut diag); + self.suggest_deref_closure_value(&mut diag); diag } @@ -1039,6 +1045,202 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> { suggest_adding_lifetime_params(self.infcx.tcx, sub, ty_sup, ty_sub, diag); } + #[allow(rustc::diagnostic_outside_of_impl)] + #[allow(rustc::untranslatable_diagnostic)] // FIXME: make this translatable + /// When encountering a lifetime error caused by the return type of a closure, check the + /// corresponding trait bound and see if dereferencing the closure return value would satisfy + /// them. If so, we produce a structured suggestion. + fn suggest_deref_closure_value(&self, diag: &mut Diag<'_>) { + let tcx = self.infcx.tcx; + let map = tcx.hir(); + + // Get the closure return value and type. + let body_id = map.body_owned_by(self.mir_def_id()); + let body = &map.body(body_id); + let value = &body.value.peel_blocks(); + let hir::Node::Expr(closure_expr) = tcx.hir_node_by_def_id(self.mir_def_id()) else { + return; + }; + let fn_call_id = tcx.parent_hir_id(self.mir_hir_id()); + let hir::Node::Expr(expr) = tcx.hir_node(fn_call_id) else { return }; + let def_id = map.enclosing_body_owner(fn_call_id); + let tables = tcx.typeck(def_id); + let Some(return_value_ty) = tables.node_type_opt(value.hir_id) else { return }; + let return_value_ty = self.infcx.resolve_vars_if_possible(return_value_ty); + + // We don't use `ty.peel_refs()` to get the number of `*`s needed to get the root type. + let mut ty = return_value_ty; + let mut count = 0; + while let ty::Ref(_, t, _) = ty.kind() { + ty = *t; + count += 1; + } + if !self.infcx.type_is_copy_modulo_regions(self.param_env, ty) { + return; + } + + // Build a new closure where the return type is an owned value, instead of a ref. + let Some(ty::Closure(did, args)) = + tables.node_type_opt(closure_expr.hir_id).as_ref().map(|ty| ty.kind()) + else { + return; + }; + let sig = args.as_closure().sig(); + let closure_sig_as_fn_ptr_ty = Ty::new_fn_ptr( + tcx, + sig.map_bound(|s| { + let unsafety = hir::Unsafety::Normal; + use rustc_target::spec::abi; + tcx.mk_fn_sig( + [s.inputs()[0]], + s.output().peel_refs(), + s.c_variadic, + unsafety, + abi::Abi::Rust, + ) + }), + ); + let parent_args = GenericArgs::identity_for_item( + tcx, + tcx.typeck_root_def_id(self.mir_def_id().to_def_id()), + ); + let closure_kind = args.as_closure().kind(); + let closure_kind_ty = Ty::from_closure_kind(tcx, closure_kind); + let tupled_upvars_ty = self.infcx.next_ty_var(TypeVariableOrigin { + kind: TypeVariableOriginKind::ClosureSynthetic, + span: closure_expr.span, + }); + let closure_args = ty::ClosureArgs::new( + tcx, + ty::ClosureArgsParts { + parent_args, + closure_kind_ty, + closure_sig_as_fn_ptr_ty, + tupled_upvars_ty, + }, + ); + let closure_ty = Ty::new_closure(tcx, *did, closure_args.args); + let closure_ty = tcx.erase_regions(closure_ty); + + let hir::ExprKind::MethodCall(_, rcvr, args, _) = expr.kind else { return }; + let Some(pos) = args + .iter() + .enumerate() + .find(|(_, arg)| arg.hir_id == closure_expr.hir_id) + .map(|(i, _)| i) + else { + return; + }; + // The found `Self` type of the method call. + let Some(possible_rcvr_ty) = tables.node_type_opt(rcvr.hir_id) else { return }; + + // The `MethodCall` expression is `Res::Err`, so we search for the method on the `rcvr_ty`. + let Some(method) = tcx.lookup_method_for_diagnostic((self.mir_def_id(), expr.hir_id)) + else { + return; + }; + + // Get the arguments for the found method, only specifying that `Self` is the receiver type. + let args = GenericArgs::for_item(tcx, method, |param, _| { + if param.index == 0 { + possible_rcvr_ty.into() + } else { + self.infcx.var_for_def(expr.span, param) + } + }); + + let preds = tcx.predicates_of(method).instantiate(tcx, args); + // Get the type for the parameter corresponding to the argument the closure with the + // lifetime error we had. + let Some(input) = tcx + .fn_sig(method) + .instantiate_identity() + .inputs() + .skip_binder() + // Methods have a `self` arg, so `pos` is actually `+ 1` to match the method call arg. + .get(pos + 1) + else { + return; + }; + + let cause = ObligationCause::misc(expr.span, self.mir_def_id()); + + enum CanSuggest { + Yes, + No, + Maybe, + } + + // Ok, the following is a HACK. We go over every predicate in the `fn` looking for the ones + // referencing the argument at hand, which is a closure with some bounds. In those, we + // re-verify that the closure we synthesized still matches the closure bound on the argument + // (this is likely unneeded) but *more importantly*, we look at the + // `::Output = ClosureRetTy` to confirm that the closure type we + // synthesized above *will* be accepted by the `where` bound corresponding to this + // argument. Put a different way, given `counts.iter().max_by_key(|(_, v)| v)`, we check + // that a new `ClosureTy` of `|(_, v)| { **v }` will be accepted by this method signature: + // ``` + // fn max_by_key(self, f: F) -> Option + // where + // Self: Sized, + // F: FnMut(&Self::Item) -> B, + // ``` + // Sadly, we can't use `ObligationCtxt` to do this, we need to modify things in place. + let mut can_suggest = CanSuggest::Maybe; + for pred in preds.predicates { + match tcx.liberate_late_bound_regions(self.mir_def_id().into(), pred.kind()) { + ty::ClauseKind::Trait(pred) + if self.infcx.can_eq(self.param_env, pred.self_ty(), *input) + && [ + tcx.lang_items().fn_trait(), + tcx.lang_items().fn_mut_trait(), + tcx.lang_items().fn_once_trait(), + ] + .contains(&Some(pred.def_id())) => + { + // This predicate is an `Fn*` trait and corresponds to the argument with the + // closure that failed the lifetime check. We verify that the arguments will + // continue to match (which didn't change, so they should, and this be a no-op). + let pred = pred.with_self_ty(tcx, closure_ty); + let o = Obligation::new(tcx, cause.clone(), self.param_env, pred); + if !self.infcx.predicate_may_hold(&o) { + // The closure we have doesn't have the right arguments for the trait bound + can_suggest = CanSuggest::No; + } else if let CanSuggest::Maybe = can_suggest { + // The closure has the right arguments + can_suggest = CanSuggest::Yes; + } + } + ty::ClauseKind::Projection(proj) + if self.infcx.can_eq(self.param_env, proj.projection_ty.self_ty(), *input) + && tcx.lang_items().fn_once_output() == Some(proj.projection_ty.def_id) => + { + // Verify that `<[closure@...] as FnOnce>::Output` matches the expected + // `Output` from the trait bound on the function called with the `[closure@...]` + // as argument. + let proj = proj.with_self_ty(tcx, closure_ty); + let o = Obligation::new(tcx, cause.clone(), self.param_env, proj); + if !self.infcx.predicate_may_hold(&o) { + // Return type doesn't match. + can_suggest = CanSuggest::No; + } else if let CanSuggest::Maybe = can_suggest { + // Return type matches, we can suggest dereferencing the closure's value. + can_suggest = CanSuggest::Yes; + } + } + _ => {} + } + } + if let CanSuggest::Yes = can_suggest { + diag.span_suggestion_verbose( + value.span.shrink_to_lo(), + "dereference the return value", + "*".repeat(count), + Applicability::MachineApplicable, + ); + } + } + #[allow(rustc::diagnostic_outside_of_impl)] #[allow(rustc::untranslatable_diagnostic)] // FIXME: make this translatable fn suggest_move_on_borrowing_closure(&self, diag: &mut Diag<'_>) { diff --git a/compiler/rustc_hir_typeck/src/lib.rs b/compiler/rustc_hir_typeck/src/lib.rs index 80fd4be53e1..3bf50cb732f 100644 --- a/compiler/rustc_hir_typeck/src/lib.rs +++ b/compiler/rustc_hir_typeck/src/lib.rs @@ -55,7 +55,7 @@ use rustc_data_structures::unord::UnordSet; use rustc_errors::{codes::*, struct_span_code_err, ErrorGuaranteed}; use rustc_hir as hir; use rustc_hir::def::{DefKind, Res}; -use rustc_hir::intravisit::Visitor; +use rustc_hir::intravisit::{Map, Visitor}; use rustc_hir::{HirIdMap, Node}; use rustc_hir_analysis::check::check_abi; use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer; @@ -435,6 +435,28 @@ fn fatally_break_rust(tcx: TyCtxt<'_>, span: Span) -> ! { diag.emit() } +pub fn lookup_method_for_diagnostic<'tcx>( + tcx: TyCtxt<'tcx>, + (def_id, hir_id): (LocalDefId, hir::HirId), +) -> Option { + let root_ctxt = TypeckRootCtxt::new(tcx, def_id); + let param_env = tcx.param_env(def_id); + let fn_ctxt = FnCtxt::new(&root_ctxt, param_env, def_id); + let hir::Node::Expr(expr) = tcx.hir().hir_node(hir_id) else { + return None; + }; + let hir::ExprKind::MethodCall(segment, rcvr, _, _) = expr.kind else { + return None; + }; + let tables = tcx.typeck(def_id); + // The found `Self` type of the method call. + let possible_rcvr_ty = tables.node_type_opt(rcvr.hir_id)?; + fn_ctxt + .lookup_method_for_diagnostic(possible_rcvr_ty, segment, expr.span, expr, rcvr) + .ok() + .map(|method| method.def_id) +} + pub fn provide(providers: &mut Providers) { method::provide(providers); *providers = Providers { @@ -442,6 +464,7 @@ pub fn provide(providers: &mut Providers) { diagnostic_only_typeck, has_typeck_results, used_trait_imports, + lookup_method_for_diagnostic: lookup_method_for_diagnostic, ..*providers }; } diff --git a/compiler/rustc_middle/src/query/keys.rs b/compiler/rustc_middle/src/query/keys.rs index 9cbc4d10146..80d854306a4 100644 --- a/compiler/rustc_middle/src/query/keys.rs +++ b/compiler/rustc_middle/src/query/keys.rs @@ -555,6 +555,19 @@ impl Key for HirId { } } +impl Key for (LocalDefId, HirId) { + type Cache = DefaultCache; + + fn default_span(&self, tcx: TyCtxt<'_>) -> Span { + tcx.hir().span(self.1) + } + + #[inline(always)] + fn key_as_def_id(&self) -> Option { + Some(self.0.into()) + } +} + impl<'tcx> Key for (ValidityRequirement, ty::ParamEnvAnd<'tcx, Ty<'tcx>>) { type Cache = DefaultCache; diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 5e4454db3e2..1866b9490ec 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -989,6 +989,9 @@ rustc_queries! { query diagnostic_only_typeck(key: LocalDefId) -> &'tcx ty::TypeckResults<'tcx> { desc { |tcx| "type-checking `{}`", tcx.def_path_str(key) } } + query lookup_method_for_diagnostic((def_id, hir_id): (LocalDefId, hir::HirId)) -> Option { + desc { |tcx| "lookup_method_for_diagnostics `{}`", tcx.def_path_str(def_id) } + } query used_trait_imports(key: LocalDefId) -> &'tcx UnordSet { desc { |tcx| "finding used_trait_imports `{}`", tcx.def_path_str(key) } diff --git a/tests/ui/closures/return-value-lifetime-error.fixed b/tests/ui/closures/return-value-lifetime-error.fixed new file mode 100644 index 00000000000..bf1f7e4a6cf --- /dev/null +++ b/tests/ui/closures/return-value-lifetime-error.fixed @@ -0,0 +1,16 @@ +//@ run-rustfix +use std::collections::HashMap; + +fn main() { + let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3]; + + let mut counts = HashMap::new(); + for num in vs { + let count = counts.entry(num).or_insert(0); + *count += 1; + } + + let _ = counts.iter().max_by_key(|(_, v)| **v); + //~^ ERROR lifetime may not live long enough + //~| HELP dereference the return value +} diff --git a/tests/ui/closures/return-value-lifetime-error.rs b/tests/ui/closures/return-value-lifetime-error.rs new file mode 100644 index 00000000000..411c91f413e --- /dev/null +++ b/tests/ui/closures/return-value-lifetime-error.rs @@ -0,0 +1,16 @@ +//@ run-rustfix +use std::collections::HashMap; + +fn main() { + let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3]; + + let mut counts = HashMap::new(); + for num in vs { + let count = counts.entry(num).or_insert(0); + *count += 1; + } + + let _ = counts.iter().max_by_key(|(_, v)| v); + //~^ ERROR lifetime may not live long enough + //~| HELP dereference the return value +} diff --git a/tests/ui/closures/return-value-lifetime-error.stderr b/tests/ui/closures/return-value-lifetime-error.stderr new file mode 100644 index 00000000000..a0ad127db28 --- /dev/null +++ b/tests/ui/closures/return-value-lifetime-error.stderr @@ -0,0 +1,16 @@ +error: lifetime may not live long enough + --> $DIR/return-value-lifetime-error.rs:13:47 + | +LL | let _ = counts.iter().max_by_key(|(_, v)| v); + | ------- ^ returning this value requires that `'1` must outlive `'2` + | | | + | | return type of closure is &'2 &i32 + | has type `&'1 (&i32, &i32)` + | +help: dereference the return value + | +LL | let _ = counts.iter().max_by_key(|(_, v)| **v); + | ++ + +error: aborting due to 1 previous error +