Limit number of tracked places, and some other perf improvements

This commit is contained in:
Jannis Christopher Köhl 2022-10-25 21:54:39 +02:00
parent da4a40f816
commit 630e17d3e4
2 changed files with 66 additions and 14 deletions

View File

@ -47,7 +47,7 @@
use std::fmt::{Debug, Formatter};
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_index::vec::IndexVec;
use rustc_middle::mir::tcx::PlaceTy;
use rustc_middle::mir::visit::{PlaceContext, Visitor};
@ -405,12 +405,31 @@ rustc_index::newtype_index!(
);
/// See [`State`].
#[derive(PartialEq, Eq, Clone, Debug)]
#[derive(PartialEq, Eq, Debug)]
enum StateData<V> {
Reachable(IndexVec<ValueIndex, V>),
Unreachable,
}
impl<V: Clone> Clone for StateData<V> {
fn clone(&self) -> Self {
match self {
Self::Reachable(x) => Self::Reachable(x.clone()),
Self::Unreachable => Self::Unreachable,
}
}
fn clone_from(&mut self, source: &Self) {
match (&mut *self, source) {
(Self::Reachable(x), Self::Reachable(y)) => {
// We go through `raw` here, because `IndexVec` currently has a naive `clone_from`.
x.raw.clone_from(&y.raw);
}
_ => *self = source.clone(),
}
}
}
/// The dataflow state for an instance of [`ValueAnalysis`].
///
/// Every instance specifies a lattice that represents the possible values of a single tracked
@ -421,9 +440,19 @@ enum StateData<V> {
/// reachable state). All operations on unreachable states are ignored.
///
/// Flooding means assigning a value (by default ``) to all tracked projections of a given place.
#[derive(PartialEq, Eq, Clone, Debug)]
#[derive(PartialEq, Eq, Debug)]
pub struct State<V>(StateData<V>);
impl<V: Clone> Clone for State<V> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
fn clone_from(&mut self, source: &Self) {
self.0.clone_from(&source.0);
}
}
impl<V: Clone + HasTop + HasBottom> State<V> {
pub fn is_reachable(&self) -> bool {
matches!(&self.0, StateData::Reachable(_))
@ -590,6 +619,7 @@ impl Map {
///
/// This is currently the only way to create a [`Map`]. The way in which the tracked places are
/// chosen is an implementation detail and may not be relied upon.
#[instrument(skip_all, level = "debug")]
pub fn from_filter<'tcx>(
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
@ -604,11 +634,12 @@ impl Map {
if tcx.sess.opts.unstable_opts.unsound_mir_opts {
// We might want to add additional limitations. If a struct has 10 boxed fields of
// itself, there will currently be `10.pow(max_derefs)` tracked places.
map.register_with_filter(tcx, body, 2, filter, &[]);
map.register_with_filter(tcx, body, 2, filter, &FxHashSet::default());
} else {
map.register_with_filter(tcx, body, 0, filter, &escaped_places(body));
}
debug!("registered {} places ({} nodes in total)", map.value_count, map.places.len());
map
}
@ -619,7 +650,7 @@ impl Map {
body: &Body<'tcx>,
max_derefs: u32,
mut filter: impl FnMut(Ty<'tcx>) -> bool,
exclude: &[Place<'tcx>],
exclude: &FxHashSet<Place<'tcx>>,
) {
// This is used to tell whether a type is `!Freeze`.
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
@ -648,10 +679,10 @@ impl Map {
ty: Ty<'tcx>,
filter: &mut impl FnMut(Ty<'tcx>) -> bool,
param_env: ty::ParamEnv<'tcx>,
exclude: &[Place<'tcx>],
exclude: &FxHashSet<Place<'tcx>>,
) {
// This currently does a linear scan, could be improved.
if exclude.contains(&Place { local, projection: tcx.intern_place_elems(projection) }) {
// This will also exclude all projections of the excluded place.
return;
}
@ -764,6 +795,10 @@ impl Map {
Ok(())
}
pub fn tracked_places(&self) -> usize {
self.value_count
}
pub fn apply(&self, place: PlaceIndex, elem: TrackElem) -> Option<PlaceIndex> {
self.projections.get(&(place, elem)).copied()
}
@ -929,20 +964,20 @@ fn iter_fields<'tcx>(
/// Returns all places, that have their reference or address taken.
///
/// This includes shared references.
fn escaped_places<'tcx>(body: &Body<'tcx>) -> Vec<Place<'tcx>> {
fn escaped_places<'tcx>(body: &Body<'tcx>) -> FxHashSet<Place<'tcx>> {
struct Collector<'tcx> {
result: Vec<Place<'tcx>>,
result: FxHashSet<Place<'tcx>>,
}
impl<'tcx> Visitor<'tcx> for Collector<'tcx> {
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
if context.is_borrow() || context.is_address_of() {
self.result.push(*place);
self.result.insert(*place);
}
}
}
let mut collector = Collector { result: Vec::new() };
let mut collector = Collector { result: FxHashSet::default() };
collector.visit_body(body);
collector.result
}

View File

@ -15,6 +15,8 @@ use rustc_span::DUMMY_SP;
use crate::MirPass;
const TRACKING_LIMIT: usize = 1000;
pub struct DataflowConstProp;
impl<'tcx> MirPass<'tcx> for DataflowConstProp {
@ -22,18 +24,33 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
sess.mir_opt_level() >= 1
}
#[instrument(skip_all level = "debug")]
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// Decide which places to track during the analysis.
let map = Map::from_filter(tcx, body, Ty::is_scalar);
// We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
// Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
// applications, where `h` is the height of the lattice. Because the height of our lattice
// is linear w.r.t. the number of tracked places, this is `O(tracked_places * n)`. However,
// because every transfer function application could traverse the whole map, this becomes
// `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
// map nodes is strongly correlated to the number of tracked places, this becomes more or
// less `O(n)` if we place a constant limit on the number of tracked places.
if map.tracked_places() > TRACKING_LIMIT {
debug!("aborted dataflow const prop due to too many tracked places");
return;
}
// Perform the actual dataflow analysis.
let analysis = ConstAnalysis::new(tcx, body, map);
let results = analysis.wrap().into_engine(tcx, body).iterate_to_fixpoint();
let results = debug_span!("analyze")
.in_scope(|| analysis.wrap().into_engine(tcx, body).iterate_to_fixpoint());
// Collect results and patch the body afterwards.
let mut visitor = CollectAndPatch::new(tcx, &results.analysis.0.map);
results.visit_reachable_with(body, &mut visitor);
visitor.visit_body(body);
debug_span!("collect").in_scope(|| results.visit_reachable_with(body, &mut visitor));
debug_span!("patch").in_scope(|| visitor.visit_body(body));
}
}