Get rid of the redundant elaboration in middle

This commit is contained in:
Michael Goulet 2024-07-06 12:33:03 -04:00
parent 90423a7abb
commit 66eb346770
7 changed files with 40 additions and 100 deletions

View File

@ -7,7 +7,6 @@ pub mod select;
pub mod solve;
pub mod specialization_graph;
mod structural_impls;
pub mod util;
use crate::mir::ConstraintCategory;
use crate::ty::abstract_const::NotConstEvaluatable;

View File

@ -1,62 +0,0 @@
use rustc_data_structures::fx::FxHashSet;
use crate::ty::{Clause, PolyTraitRef, ToPolyTraitRef, TyCtxt, Upcast};
/// Given a [`PolyTraitRef`], get the [`Clause`]s implied by the trait's definition.
///
/// This only exists in `rustc_middle` because the more powerful elaborator depends on
/// `rustc_infer` for elaborating outlives bounds -- this should only be used for pretty
/// printing.
pub fn super_predicates_for_pretty_printing<'tcx>(
tcx: TyCtxt<'tcx>,
trait_ref: PolyTraitRef<'tcx>,
) -> impl Iterator<Item = Clause<'tcx>> {
let clause = trait_ref.upcast(tcx);
Elaborator { tcx, visited: FxHashSet::from_iter([clause]), stack: vec![clause] }
}
/// Like [`super_predicates_for_pretty_printing`], except it only returns traits and filters out
/// all other [`Clause`]s.
pub fn supertraits_for_pretty_printing<'tcx>(
tcx: TyCtxt<'tcx>,
trait_ref: PolyTraitRef<'tcx>,
) -> impl Iterator<Item = PolyTraitRef<'tcx>> {
super_predicates_for_pretty_printing(tcx, trait_ref).filter_map(|clause| {
clause.as_trait_clause().map(|trait_clause| trait_clause.to_poly_trait_ref())
})
}
struct Elaborator<'tcx> {
tcx: TyCtxt<'tcx>,
visited: FxHashSet<Clause<'tcx>>,
stack: Vec<Clause<'tcx>>,
}
impl<'tcx> Elaborator<'tcx> {
fn elaborate(&mut self, trait_ref: PolyTraitRef<'tcx>) {
let super_predicates =
self.tcx.explicit_super_predicates_of(trait_ref.def_id()).predicates.iter().filter_map(
|&(pred, _)| {
let clause = pred.instantiate_supertrait(self.tcx, trait_ref);
self.visited.insert(clause).then_some(clause)
},
);
self.stack.extend(super_predicates);
}
}
impl<'tcx> Iterator for Elaborator<'tcx> {
type Item = Clause<'tcx>;
fn next(&mut self) -> Option<Clause<'tcx>> {
if let Some(clause) = self.stack.pop() {
if let Some(trait_clause) = clause.as_trait_clause() {
self.elaborate(trait_clause.to_poly_trait_ref());
}
Some(clause)
} else {
None
}
}
}

View File

@ -37,7 +37,7 @@ use crate::ty::{GenericArg, GenericArgs, GenericArgsRef};
use rustc_ast::{self as ast, attr};
use rustc_data_structures::defer;
use rustc_data_structures::fingerprint::Fingerprint;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::intern::Interned;
use rustc_data_structures::profiling::SelfProfilerRef;
use rustc_data_structures::sharded::{IntoPointer, ShardedHashMap};
@ -532,10 +532,6 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
self.trait_def(trait_def_id).implement_via_object
}
fn supertrait_def_ids(self, trait_def_id: DefId) -> impl IntoIterator<Item = DefId> {
self.supertrait_def_ids(trait_def_id)
}
fn delay_bug(self, msg: impl ToString) -> ErrorGuaranteed {
self.dcx().span_delayed_bug(DUMMY_SP, msg.to_string())
}
@ -2495,25 +2491,7 @@ impl<'tcx> TyCtxt<'tcx> {
/// to identify which traits may define a given associated type to help avoid cycle errors,
/// and to make size estimates for vtable layout computation.
pub fn supertrait_def_ids(self, trait_def_id: DefId) -> impl Iterator<Item = DefId> + 'tcx {
let mut set = FxHashSet::default();
let mut stack = vec![trait_def_id];
set.insert(trait_def_id);
iter::from_fn(move || -> Option<DefId> {
let trait_did = stack.pop()?;
let generic_predicates = self.explicit_super_predicates_of(trait_did);
for (predicate, _) in generic_predicates.predicates {
if let ty::ClauseKind::Trait(data) = predicate.kind().skip_binder() {
if set.insert(data.def_id()) {
stack.push(data.def_id());
}
}
}
Some(trait_did)
})
rustc_type_ir::elaborate::supertrait_def_ids(self, trait_def_id)
}
/// Given a closure signature, returns an equivalent fn signature. Detuples

View File

@ -1,7 +1,6 @@
use crate::mir::interpret::{AllocRange, GlobalAlloc, Pointer, Provenance, Scalar};
use crate::query::IntoQueryParam;
use crate::query::Providers;
use crate::traits::util::{super_predicates_for_pretty_printing, supertraits_for_pretty_printing};
use crate::ty::GenericArgKind;
use crate::ty::{
ConstInt, Expr, ParamConst, ScalarInt, Term, TermKind, TypeFoldable, TypeSuperFoldable,
@ -23,6 +22,7 @@ use rustc_span::symbol::{kw, Ident, Symbol};
use rustc_span::FileNameDisplayPreference;
use rustc_target::abi::Size;
use rustc_target::spec::abi::Abi;
use rustc_type_ir::{elaborate, Upcast as _};
use smallvec::SmallVec;
use std::cell::Cell;
@ -1255,14 +1255,14 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
entry.has_fn_once = true;
return;
} else if self.tcx().is_lang_item(trait_def_id, LangItem::FnMut) {
let super_trait_ref = supertraits_for_pretty_printing(self.tcx(), trait_ref)
let super_trait_ref = elaborate::supertraits(self.tcx(), trait_ref)
.find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
.unwrap();
fn_traits.entry(super_trait_ref).or_default().fn_mut_trait_ref = Some(trait_ref);
return;
} else if self.tcx().is_lang_item(trait_def_id, LangItem::Fn) {
let super_trait_ref = supertraits_for_pretty_printing(self.tcx(), trait_ref)
let super_trait_ref = elaborate::supertraits(self.tcx(), trait_ref)
.find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
.unwrap();
@ -1343,10 +1343,11 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
let bound_principal_with_self = bound_principal
.with_self_ty(cx.tcx(), cx.tcx().types.trait_object_dummy_self);
let super_projections: Vec<_> =
super_predicates_for_pretty_printing(cx.tcx(), bound_principal_with_self)
.filter_map(|clause| clause.as_projection_clause())
.collect();
let clause: ty::Clause<'tcx> = bound_principal_with_self.upcast(cx.tcx());
let super_projections: Vec<_> = elaborate::elaborate(cx.tcx(), [clause])
.filter_only_self()
.filter_map(|clause| clause.as_projection_clause())
.collect();
let mut projections: Vec<_> = predicates
.projection_bounds()

View File

@ -6,7 +6,7 @@ use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::visit::TypeVisitableExt as _;
use rustc_type_ir::{self as ty, Interner, TraitPredicate, Upcast as _};
use rustc_type_ir::{self as ty, elaborate, Interner, TraitPredicate, Upcast as _};
use tracing::{instrument, trace};
use crate::delegate::SolverDelegate;
@ -862,8 +862,7 @@ where
.auto_traits()
.into_iter()
.chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
self.cx()
.supertrait_def_ids(principal_def_id)
elaborate::supertrait_def_ids(self.cx(), principal_def_id)
.into_iter()
.filter(|def_id| self.cx().trait_is_auto(*def_id))
}))

View File

@ -229,6 +229,34 @@ impl<I: Interner, O: Elaboratable<I>> Iterator for Elaborator<I, O> {
// Supertrait iterator
///////////////////////////////////////////////////////////////////////////
/// Computes the def-ids of the transitive supertraits of `trait_def_id`. This (intentionally)
/// does not compute the full elaborated super-predicates but just the set of def-ids. It is used
/// to identify which traits may define a given associated type to help avoid cycle errors,
/// and to make size estimates for vtable layout computation.
pub fn supertrait_def_ids<I: Interner>(
cx: I,
trait_def_id: I::DefId,
) -> impl Iterator<Item = I::DefId> {
let mut set = HashSet::default();
let mut stack = vec![trait_def_id];
set.insert(trait_def_id);
std::iter::from_fn(move || {
let trait_def_id = stack.pop()?;
for (predicate, _) in cx.explicit_super_predicates_of(trait_def_id).iter_identity() {
if let ty::ClauseKind::Trait(data) = predicate.kind().skip_binder() {
if set.insert(data.def_id()) {
stack.push(data.def_id());
}
}
}
Some(trait_def_id)
})
}
pub fn supertraits<I: Interner>(
tcx: I,
trait_ref: ty::Binder<I, ty::TraitRef<I>>,

View File

@ -253,9 +253,6 @@ pub trait Interner:
fn trait_may_be_implemented_via_object(self, trait_def_id: Self::DefId) -> bool;
fn supertrait_def_ids(self, trait_def_id: Self::DefId)
-> impl IntoIterator<Item = Self::DefId>;
fn delay_bug(self, msg: impl ToString) -> Self::ErrorGuaranteed;
fn is_general_coroutine(self, coroutine_def_id: Self::DefId) -> bool;