rewrite stack dependent overflow handling

This commit is contained in:
lcnr 2023-08-03 14:43:26 +02:00
parent a745cbb042
commit 8aca388af8
5 changed files with 286 additions and 179 deletions

View File

@ -1,7 +1,6 @@
use std::ops::ControlFlow;
use rustc_data_structures::intern::Interned;
use rustc_query_system::cache::Cache;
use crate::infer::canonical::{CanonicalVarValues, QueryRegionConstraints};
use crate::traits::query::NoSolution;
@ -11,9 +10,10 @@ use crate::ty::{
TypeVisitor,
};
mod cache;
pub mod inspect;
pub type EvaluationCache<'tcx> = Cache<CanonicalInput<'tcx>, QueryResult<'tcx>>;
pub use cache::{CacheData, EvaluationCache};
/// A goal is a statement, i.e. `predicate`, we want to prove
/// given some assumptions, i.e. `param_env`.

View File

@ -0,0 +1,100 @@
use super::{CanonicalInput, QueryResult};
use crate::ty::TyCtxt;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_data_structures::sync::Lock;
use rustc_query_system::cache::WithDepNode;
use rustc_query_system::dep_graph::DepNodeIndex;
use rustc_session::Limit;
/// The trait solver cache used by `-Ztrait-solver=next`.
///
/// Each canonical goal maps to a `CacheEntry`, which can hold one result
/// computed without hitting the recursion limit as well as results which
/// did hit it.
///
/// FIXME(@lcnr): link to some official documentation of how
/// this works.
#[derive(Default)]
pub struct EvaluationCache<'tcx> {
// All cached entries, keyed by the canonical goal they belong to.
map: Lock<FxHashMap<CanonicalInput<'tcx>, CacheEntry<'tcx>>>,
}
/// The data handed out by `EvaluationCache::get` on a cache hit.
pub struct CacheData<'tcx> {
/// The cached result of the goal.
pub result: QueryResult<'tcx>,
/// The depth reported for the cached computation: the stored depth for
/// results which did not overflow, and the depth `get` was queried with
/// for results which did.
pub reached_depth: usize,
/// Whether the returned result was cached as one which hit the recursion limit.
pub encountered_overflow: bool,
}
impl<'tcx> EvaluationCache<'tcx> {
    /// Stores the final `result` of `key` in the global cache.
    ///
    /// The given `cycle_participants` are merged into the ones already
    /// recorded for `key`. A result which overflowed is stored under its
    /// `reached_depth`, while a result which did not overflow replaces any
    /// previously stored successful result.
    pub fn insert(
        &self,
        key: CanonicalInput<'tcx>,
        reached_depth: usize,
        did_overflow: bool,
        cycle_participants: FxHashSet<CanonicalInput<'tcx>>,
        dep_node: DepNodeIndex,
        result: QueryResult<'tcx>,
    ) {
        let mut entries = self.map.borrow_mut();
        let slot = entries.entry(key).or_default();
        slot.cycle_participants.extend(cycle_participants);

        let data = WithDepNode::new(dep_node, result);
        if !did_overflow {
            slot.success = Some(Success { data, reached_depth });
            return;
        }
        slot.with_overflow.insert(reached_depth, data);
    }

    /// Looks up a cached result for `key`, taking the recursion limit and
    /// root goals of coinductive cycles into account.
    ///
    /// Returns `None` if there is no entry for `key` or if
    /// `cycle_participant_in_stack` reports that one of the recorded cycle
    /// participants is on the stack. Otherwise a result which did not
    /// overflow is preferred as long as its reached depth is within
    /// `available_depth`; failing that, a result which overflowed and is
    /// stored under exactly `available_depth` is used.
    ///
    /// If this returns `Some` the cache result can be used.
    pub fn get(
        &self,
        tcx: TyCtxt<'tcx>,
        key: CanonicalInput<'tcx>,
        cycle_participant_in_stack: impl FnOnce(&FxHashSet<CanonicalInput<'tcx>>) -> bool,
        available_depth: Limit,
    ) -> Option<CacheData<'tcx>> {
        let entries = self.map.borrow();
        let slot = entries.get(&key)?;

        if cycle_participant_in_stack(&slot.cycle_participants) {
            return None;
        }

        // Only a success whose depth fits into the available depth is usable.
        let usable_success = slot
            .success
            .as_ref()
            .filter(|hit| available_depth.value_within_limit(hit.reached_depth));

        match usable_success {
            Some(hit) => Some(CacheData {
                result: hit.data.get(tcx),
                reached_depth: hit.reached_depth,
                encountered_overflow: false,
            }),
            None => slot.with_overflow.get(&available_depth.0).map(|data| CacheData {
                result: data.get(tcx),
                reached_depth: available_depth.0,
                encountered_overflow: true,
            }),
        }
    }
}
/// A cached result whose computation did not hit the recursion limit.
struct Success<'tcx> {
// The result together with the dep node it was inserted with.
data: WithDepNode<QueryResult<'tcx>>,
// The depth recorded for this result; `EvaluationCache::get` only hands
// the result out if this depth is within the available depth.
reached_depth: usize,
}
/// The cache entry for a goal `CanonicalInput`.
///
/// This contains results whose computation never hit the
/// recursion limit in `success`, and all results which hit
/// the recursion limit in `with_overflow`.
#[derive(Default)]
struct CacheEntry<'tcx> {
/// The most recently inserted result which did not hit the recursion
/// limit, if there is one.
success: Option<Success<'tcx>>,
/// We have to be careful when caching roots of cycles.
///
/// See the doc comment of `StackEntry::cycle_participants` for more
/// details.
cycle_participants: FxHashSet<CanonicalInput<'tcx>>,
/// Results which hit the recursion limit, keyed by the `reached_depth`
/// they were inserted with.
with_overflow: FxHashMap<usize, WithDepNode<QueryResult<'tcx>>>,
}

View File

@ -340,6 +340,7 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
) -> Result<(bool, Certainty, Vec<Goal<'tcx, ty::Predicate<'tcx>>>), NoSolution> {
let (orig_values, canonical_goal) = self.canonicalize_goal(goal);
let mut goal_evaluation = self.inspect.new_goal_evaluation(goal, is_normalizes_to_hack);
let encountered_overflow = self.search_graph.encountered_overflow();
let canonical_response = EvalCtxt::evaluate_canonical_goal(
self.tcx(),
self.search_graph,
@ -388,6 +389,7 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
&& !self.search_graph.in_cycle()
{
debug!("rerunning goal to check result is stable");
self.search_graph.reset_encountered_overflow(encountered_overflow);
let (_orig_values, canonical_goal) = self.canonicalize_goal(goal);
let new_canonical_response = EvalCtxt::evaluate_canonical_goal(
self.tcx(),

View File

@ -1,28 +1,42 @@
mod cache;
mod overflow;
pub(super) use overflow::OverflowHandler;
use rustc_middle::traits::solve::inspect::CacheHit;
use self::cache::ProvisionalEntry;
use cache::ProvisionalCache;
use overflow::OverflowData;
use rustc_index::IndexVec;
use rustc_middle::dep_graph::DepKind;
use rustc_middle::traits::solve::{CanonicalInput, Certainty, EvaluationCache, QueryResult};
use rustc_middle::ty::TyCtxt;
use std::{collections::hash_map::Entry, mem};
use super::inspect::ProofTreeBuilder;
use super::SolverMode;
use cache::ProvisionalCache;
use rustc_data_structures::fx::FxHashSet;
use rustc_index::Idx;
use rustc_index::IndexVec;
use rustc_middle::dep_graph::DepKind;
use rustc_middle::traits::solve::inspect::CacheHit;
use rustc_middle::traits::solve::CacheData;
use rustc_middle::traits::solve::{CanonicalInput, Certainty, EvaluationCache, QueryResult};
use rustc_middle::ty::TyCtxt;
use rustc_session::Limit;
use std::{collections::hash_map::Entry, mem};
rustc_index::newtype_index! {
pub struct StackDepth {}
}
struct StackElem<'tcx> {
/// A goal which is currently on the solver stack.
#[derive(Debug)]
struct StackEntry<'tcx> {
/// The canonical goal this entry belongs to.
input: CanonicalInput<'tcx>,
/// The depth still available to this goal; the limit for nested goals
/// is derived from it.
available_depth: Limit,
/// The maximum depth reached by this stack entry, only up to date
/// for the top of the stack and lazily updated for the rest.
reached_depth: StackDepth,
/// Whether overflow was encountered while computing this goal. Only
/// up to date for the top of the stack, as it is propagated to the next
/// entry when a goal gets popped.
encountered_overflow: bool,
/// Set once a cycle with this goal as its head is encountered. A finished
/// goal is only reevaluated if this is set and its response changed.
has_been_used: bool,
/// We put only the root goal of a coinductive cycle into the global cache.
///
/// If we were to use that result when later trying to prove another cycle
/// participant, we can end up with unstable query results.
///
/// See tests/ui/new-solver/coinduction/incompleteness-unstable-result.rs for
/// an example of where this is needed.
cycle_participants: FxHashSet<CanonicalInput<'tcx>>,
}
pub(super) struct SearchGraph<'tcx> {
@ -31,8 +45,7 @@ pub(super) struct SearchGraph<'tcx> {
/// The stack of goals currently being computed.
///
/// An element is *deeper* in the stack if its index is *lower*.
stack: IndexVec<StackDepth, StackElem<'tcx>>,
overflow_data: OverflowData,
stack: IndexVec<StackDepth, StackEntry<'tcx>>,
provisional_cache: ProvisionalCache<'tcx>,
}
@ -42,7 +55,6 @@ impl<'tcx> SearchGraph<'tcx> {
mode,
local_overflow_limit: tcx.recursion_limit().0.ilog2() as usize,
stack: Default::default(),
overflow_data: OverflowData::new(tcx),
provisional_cache: ProvisionalCache::empty(),
}
}
@ -55,15 +67,34 @@ impl<'tcx> SearchGraph<'tcx> {
self.local_overflow_limit
}
/// We do not use the global cache during coherence.
/// Records a cache hit on the goal at the top of the stack.
///
/// The depth reached by the cached computation is `additional_depth` past
/// the next stack index. The top entry's `reached_depth` is raised to that
/// value if necessary and it inherits `encountered_overflow`. Does nothing
/// if the stack is empty.
#[instrument(level = "debug", skip(self))]
fn on_cache_hit(&mut self, additional_depth: usize, encountered_overflow: bool) {
    let hit_depth = self.stack.next_index().plus(additional_depth);
    let Some(top) = self.stack.raw.last_mut() else {
        return;
    };
    top.encountered_overflow |= encountered_overflow;
    top.reached_depth = top.reached_depth.max(hit_depth);
}
/// Removes the topmost goal from the stack and lazily propagates its
/// reached depth and overflow state to the goal below it, if any.
///
/// Always use this method instead of popping from the stack directly,
/// as we would otherwise fail to track overflow and recursion depth
/// correctly.
fn pop_stack(&mut self) -> StackEntry<'tcx> {
    let popped = self.stack.pop().unwrap();
    if let Some(parent) = self.stack.raw.last_mut() {
        parent.encountered_overflow |= popped.encountered_overflow;
        parent.reached_depth = parent.reached_depth.max(popped.reached_depth);
    }
    popped
}
/// The trait solver behavior is different for coherence
/// so we would have to add the solver mode to the cache key.
/// This is probably not worth it as trait solving during
/// coherence tends to already be incredibly fast.
///
/// We could add another global cache for coherence instead,
/// but that's effort so let's only do it if necessary.
/// so we use a separate cache. Alternatively we could use
/// a single cache and share it between coherence and ordinary
/// trait solving.
pub(super) fn global_cache(&self, tcx: TyCtxt<'tcx>) -> &'tcx EvaluationCache<'tcx> {
match self.mode {
SolverMode::Normal => &tcx.new_solver_evaluation_cache,
@ -93,6 +124,47 @@ impl<'tcx> SearchGraph<'tcx> {
}
}
/// Returns whether the goal at the top of the stack encountered overflow,
/// or `false` if the stack is empty.
///
/// This should only be used for the check in `evaluate_goal`.
pub(super) fn encountered_overflow(&self) -> bool {
    self.stack.raw.last().map_or(false, |top| top.encountered_overflow)
}
/// Resets `encountered_overflow` of the current goal to the given value.
///
/// `encountered_overflow` is expected to be the value previously returned
/// by `encountered_overflow()`. While evaluating goals the flag is only ever
/// switched on, so restoring an earlier state requires overwriting it here,
/// which in particular clears overflow encountered in the meantime. Merely
/// setting the flag when the snapshot is `true` would never change anything.
///
/// Does nothing if the stack is empty.
///
/// This should only be used for the check in `evaluate_goal`.
pub(super) fn reset_encountered_overflow(&mut self, encountered_overflow: bool) {
    if let Some(last) = self.stack.raw.last_mut() {
        last.encountered_overflow = encountered_overflow;
    }
}
/// Computes the depth available to goals nested in the top of the stack.
///
/// With an empty stack this is the recursion limit of `tcx`. Otherwise it
/// is usually one less than the depth available to the current goal. Once
/// the current goal encountered overflow, nested goals only get a quarter
/// of its available depth to prevent hangs in case of exponential blowup.
///
/// Returns `None` if the current goal has no depth left.
fn allowed_depth_for_nested(
    tcx: TyCtxt<'tcx>,
    stack: &IndexVec<StackDepth, StackEntry<'tcx>>,
) -> Option<Limit> {
    let Some(parent) = stack.raw.last() else {
        return Some(tcx.recursion_limit());
    };

    let remaining = parent.available_depth.0;
    if remaining == 0 {
        return None;
    }

    let nested = if parent.encountered_overflow { remaining / 4 } else { remaining - 1 };
    Some(Limit(nested))
}
/// Tries putting the new goal on the stack, returning an error if it is already cached.
///
/// This correctly updates the provisional cache if there is a cycle.
@ -101,18 +173,24 @@ impl<'tcx> SearchGraph<'tcx> {
&mut self,
tcx: TyCtxt<'tcx>,
input: CanonicalInput<'tcx>,
available_depth: Limit,
inspect: &mut ProofTreeBuilder<'tcx>,
) -> Result<(), QueryResult<'tcx>> {
// Look at the provisional cache to check for cycles.
let cache = &mut self.provisional_cache;
match cache.lookup_table.entry(input) {
// No entry, simply push this goal on the stack after dealing with overflow.
// No entry, simply push this goal on the stack.
Entry::Vacant(v) => {
if self.overflow_data.has_overflow(self.stack.len()) {
return Err(self.deal_with_overflow(tcx, input));
}
let depth = self.stack.push(StackElem { input, has_been_used: false });
let depth = self.stack.next_index();
let entry = StackEntry {
input,
available_depth,
reached_depth: depth,
encountered_overflow: false,
has_been_used: false,
cycle_participants: Default::default(),
};
assert_eq!(self.stack.push(entry), depth);
let response = Self::response_no_constraints(tcx, input, Certainty::Yes);
let entry_index = cache.entries.push(ProvisionalEntry { response, depth, input });
v.insert(entry_index);
@ -136,8 +214,12 @@ impl<'tcx> SearchGraph<'tcx> {
debug!("encountered cycle with depth {stack_depth:?}");
cache.add_dependency_of_leaf_on(entry_index);
let mut iter = self.stack.iter_mut();
let root = iter.nth(stack_depth.as_usize()).unwrap();
for e in iter {
root.cycle_participants.insert(e.input);
}
self.stack[stack_depth].has_been_used = true;
// NOTE: The goals on the stack aren't the only goals involved in this cycle.
// We can also depend on goals which aren't part of the stack but coinductively
// depend on the stack themselves. We already checked whether all the goals
@ -148,6 +230,9 @@ impl<'tcx> SearchGraph<'tcx> {
.iter()
.all(|g| g.input.value.goal.predicate.is_coinductive(tcx))
{
// If we're in a coinductive cycle, we have to retry proving the current goal
// until we reach a fixpoint.
self.stack[stack_depth].has_been_used = true;
Err(cache.provisional_result(entry_index))
} else {
Err(Self::response_no_constraints(tcx, input, Certainty::OVERFLOW))
@ -173,13 +258,12 @@ impl<'tcx> SearchGraph<'tcx> {
&mut self,
actual_input: CanonicalInput<'tcx>,
response: QueryResult<'tcx>,
) -> bool {
let stack_elem = self.stack.pop().unwrap();
let StackElem { input, has_been_used } = stack_elem;
assert_eq!(input, actual_input);
) -> Result<StackEntry<'tcx>, ()> {
let stack_entry = self.pop_stack();
assert_eq!(stack_entry.input, actual_input);
let cache = &mut self.provisional_cache;
let provisional_entry_index = *cache.lookup_table.get(&input).unwrap();
let provisional_entry_index = *cache.lookup_table.get(&stack_entry.input).unwrap();
let provisional_entry = &mut cache.entries[provisional_entry_index];
// We eagerly update the response in the cache here. If we have to reevaluate
// this goal we use the new response when hitting a cycle, and we definitely
@ -188,7 +272,7 @@ impl<'tcx> SearchGraph<'tcx> {
// Was the current goal the root of a cycle and was the provisional response
// different from the final one.
if has_been_used && prev_response != response {
if stack_entry.has_been_used && prev_response != response {
// If so, remove all entries whose result depends on this goal
// from the provisional cache...
//
@ -201,29 +285,44 @@ impl<'tcx> SearchGraph<'tcx> {
cache.entries.truncate(provisional_entry_index.index() + 1);
// ...and finally push our goal back on the stack and reevaluate it.
self.stack.push(StackElem { input, has_been_used: false });
false
self.stack.push(StackEntry { has_been_used: false, ..stack_entry });
Err(())
} else {
true
Ok(stack_entry)
}
}
pub(super) fn with_new_goal(
&mut self,
tcx: TyCtxt<'tcx>,
canonical_input: CanonicalInput<'tcx>,
input: CanonicalInput<'tcx>,
inspect: &mut ProofTreeBuilder<'tcx>,
mut loop_body: impl FnMut(&mut Self, &mut ProofTreeBuilder<'tcx>) -> QueryResult<'tcx>,
) -> QueryResult<'tcx> {
let Some(available_depth) = Self::allowed_depth_for_nested(tcx, &self.stack) else {
if let Some(last) = self.stack.raw.last_mut() {
last.encountered_overflow = true;
}
return Self::response_no_constraints(tcx, input, Certainty::OVERFLOW);
};
if inspect.use_global_cache() {
if let Some(result) = self.global_cache(tcx).get(&canonical_input, tcx) {
debug!(?canonical_input, ?result, "cache hit");
inspect.cache_hit(CacheHit::Global);
if let Some(CacheData { result, reached_depth, encountered_overflow }) =
self.global_cache(tcx).get(
tcx,
input,
|cycle_participants| {
self.stack.iter().any(|entry| cycle_participants.contains(&entry.input))
},
available_depth,
)
{
self.on_cache_hit(reached_depth, encountered_overflow);
return result;
}
}
match self.try_push_stack(tcx, canonical_input, inspect) {
match self.try_push_stack(tcx, input, available_depth, inspect) {
Ok(()) => {}
// Our goal is already on the stack, eager return.
Err(response) => return response,
@ -232,59 +331,58 @@ impl<'tcx> SearchGraph<'tcx> {
// This is for global caching, so we properly track query dependencies.
// Everything that affects the `Result` should be performed within this
// `with_anon_task` closure.
let (result, dep_node) = tcx.dep_graph.with_anon_task(tcx, DepKind::TraitSelect, || {
self.repeat_while_none(
|this| {
let result = this.deal_with_overflow(tcx, canonical_input);
let _ = this.stack.pop().unwrap();
result
},
|this| {
let result = loop_body(this, inspect);
this.try_finalize_goal(canonical_input, result).then(|| result)
},
)
});
let ((final_entry, result), dep_node) =
tcx.dep_graph.with_anon_task(tcx, DepKind::TraitSelect, || {
// We run our goal in a loop to handle coinductive cycles. If we fail to reach a
// fixpoint we return overflow.
for _ in 0..self.local_overflow_limit() {
let result = loop_body(self, inspect);
if let Ok(final_entry) = self.try_finalize_goal(input, result) {
return (final_entry, result);
}
}
debug!("canonical cycle overflow");
let current_entry = self.pop_stack();
let result = Self::response_no_constraints(tcx, input, Certainty::OVERFLOW);
(current_entry, result)
});
let cache = &mut self.provisional_cache;
let provisional_entry_index = *cache.lookup_table.get(&canonical_input).unwrap();
let provisional_entry_index = *cache.lookup_table.get(&input).unwrap();
let provisional_entry = &mut cache.entries[provisional_entry_index];
let depth = provisional_entry.depth;
// If not, we're done with this goal.
// We're now done with this goal. In case this goal is involved in a cycle
// do not remove it from the provisional cache and do not add it to the global
// cache.
//
// Check whether that this goal doesn't depend on a goal deeper on the stack
// and if so, move it to the global cache.
//
// Note that if any nested goal were to depend on something deeper on the stack,
// this would have also updated the depth of the current goal.
// It is not possible for any nested goal to depend on something deeper on the
// stack, as this would have also updated the depth of the current goal.
if depth == self.stack.next_index() {
// If the current goal is the head of a cycle, we drop all other
// cycle participants without moving them to the global cache.
let other_cycle_participants = provisional_entry_index.index() + 1;
for (i, entry) in cache.entries.drain_enumerated(other_cycle_participants..) {
for (i, entry) in cache.entries.drain_enumerated(provisional_entry_index.index()..) {
let actual_index = cache.lookup_table.remove(&entry.input);
debug_assert_eq!(Some(i), actual_index);
debug_assert!(entry.depth == depth);
}
let current_goal = cache.entries.pop().unwrap();
let actual_index = cache.lookup_table.remove(&current_goal.input);
debug_assert_eq!(Some(provisional_entry_index), actual_index);
debug_assert!(current_goal.depth == depth);
// We move the root goal to the global cache if we either did not hit an overflow or if it's
// the root goal as that will now always hit the same overflow limit.
// When encountering a cycle, both inductive and coinductive, we only
// move the root into the global cache. We also store all other cycle
// participants involved.
//
// NOTE: We cannot move any non-root goals to the global cache. When replaying the root goal's
// dependencies, our non-root goal may no longer appear as child of the root goal.
//
// See https://github.com/rust-lang/rust/pull/108071 for some additional context.
let can_cache = inspect.use_global_cache()
&& (!self.overflow_data.did_overflow() || self.stack.is_empty());
if can_cache {
self.global_cache(tcx).insert(current_goal.input, dep_node, current_goal.response)
}
// We disable the global cache entry of the root goal if a cycle
// participant is on the stack. This is necessary to prevent unstable
// results. See the comment of `StackEntry::cycle_participants` for
// more details.
let reached_depth = final_entry.reached_depth.as_usize() - self.stack.len();
self.global_cache(tcx).insert(
input,
reached_depth,
final_entry.encountered_overflow,
final_entry.cycle_participants,
dep_node,
result,
)
}
result

View File

@ -1,93 +0,0 @@
use rustc_infer::infer::canonical::Canonical;
use rustc_infer::traits::query::NoSolution;
use rustc_middle::traits::solve::{Certainty, QueryResult};
use rustc_middle::ty::TyCtxt;
use rustc_session::Limit;
use super::SearchGraph;
use crate::solve::response_no_constraints_raw;
/// When detecting a solver overflow, we return ambiguity. Overflow can be
/// *hidden* by either a fatal error in an **AND** or a trivial success in an **OR**.
///
/// This is an issue in case of exponential blowup, e.g. if each goal on the stack
/// has multiple nested (overflowing) candidates. To deal with this, we reduce the limit
/// used by the solver when hitting the default limit for the first time.
///
/// FIXME: Get tests where always using the `default_limit` results in a hang and refer
/// to them here. We can also improve the overflow strategy if necessary.
pub(super) struct OverflowData {
/// The limit this data was created with, i.e. the recursion limit of the `TyCtxt`.
default_limit: Limit,
/// The limit which is currently enforced. Starts out as `default_limit` and is
/// lowered to an eighth of it once overflow has been dealt with.
current_limit: Limit,
/// When proving an **AND** we have to repeatedly iterate over the yet unproven goals.
///
/// Because of this each iteration also increases the depth in addition to the stack
/// depth.
additional_depth: usize,
}
impl OverflowData {
    /// Creates overflow data whose default and current limit are both the
    /// recursion limit of `tcx`, with no additional depth.
    pub(super) fn new(tcx: TyCtxt<'_>) -> OverflowData {
        let limit = tcx.recursion_limit();
        OverflowData { default_limit: limit, current_limit: limit, additional_depth: 0 }
    }

    /// Whether the current limit differs from the default limit, which is
    /// the case once the limit has been reduced after hitting overflow.
    #[inline]
    pub(super) fn did_overflow(&self) -> bool {
        let Self { default_limit, current_limit, .. } = self;
        default_limit.0 != current_limit.0
    }

    /// Whether `depth` plus the additional depth accumulated so far exceeds
    /// the current limit.
    #[inline]
    pub(super) fn has_overflow(&self, depth: usize) -> bool {
        let effective_depth = depth + self.additional_depth;
        !self.current_limit.value_within_limit(effective_depth)
    }

    /// Updating the current limit when hitting overflow.
    fn deal_with_overflow(&mut self) {
        // After hitting overflow all future goals use a reduced limit, an
        // eighth of the default one, to prevent hangs if there's an
        // exponential blowup.
        let reduced = self.default_limit.0 / 8;
        self.current_limit.0 = reduced;
    }
}
/// Shared overflow handling for anything which has access to a `SearchGraph`.
pub(in crate::solve) trait OverflowHandler<'tcx> {
    /// The search graph whose overflow data is used and updated.
    fn search_graph(&mut self) -> &mut SearchGraph<'tcx>;

    /// Runs `loop_body` until it returns `Some`, returning that result.
    ///
    /// Each iteration which returns `None` increases the additional depth by
    /// one. Once the current stack depth together with the additional depth
    /// exceeds the limit, the overflow limit gets reduced and the result of
    /// `on_overflow` is returned instead. In both cases the additional depth
    /// is restored to its value from before the call.
    fn repeat_while_none<T>(
        &mut self,
        on_overflow: impl FnOnce(&mut Self) -> Result<T, NoSolution>,
        mut loop_body: impl FnMut(&mut Self) -> Option<Result<T, NoSolution>>,
    ) -> Result<T, NoSolution> {
        let initial_additional_depth = self.search_graph().overflow_data.additional_depth;
        let stack_len = self.search_graph().stack.len();

        loop {
            if self.search_graph().overflow_data.has_overflow(stack_len) {
                break;
            }

            match loop_body(self) {
                Some(result) => {
                    self.search_graph().overflow_data.additional_depth = initial_additional_depth;
                    return result;
                }
                None => self.search_graph().overflow_data.additional_depth += 1,
            }
        }

        // We ran out of depth without getting a result.
        self.search_graph().overflow_data.additional_depth = initial_additional_depth;
        self.search_graph().overflow_data.deal_with_overflow();
        on_overflow(self)
    }
}
/// The search graph acts as its own overflow handler.
impl<'tcx> OverflowHandler<'tcx> for SearchGraph<'tcx> {
// Simply hands out the search graph itself.
fn search_graph(&mut self) -> &mut SearchGraph<'tcx> {
self
}
}
impl<'tcx> SearchGraph<'tcx> {
    /// Reduces the overflow limit used from now on and returns the overflow
    /// response for `goal`: a response without constraints and with
    /// `Certainty::OVERFLOW`, built from the universe and variables of `goal`.
    pub fn deal_with_overflow(
        &mut self,
        tcx: TyCtxt<'tcx>,
        goal: Canonical<'tcx, impl Sized>,
    ) -> QueryResult<'tcx> {
        self.overflow_data.deal_with_overflow();

        let overflow_response =
            response_no_constraints_raw(tcx, goal.max_universe, goal.variables, Certainty::OVERFLOW);
        Ok(overflow_response)
    }
}