Add utility functions to check for trait impl

This commit is contained in:
mcarton 2016-01-18 13:10:13 +01:00
parent 90cbc858e9
commit fb6b3bed0f
1 changed files with 78 additions and 10 deletions

View File

@ -1,20 +1,20 @@
use rustc::lint::*;
use rustc_front::hir::*;
use consts::constant;
use reexport::*;
use syntax::codemap::{ExpnInfo, Span, ExpnFormat};
use rustc::front::map::Node::*;
use rustc::lint::*;
use rustc::middle::def_id::DefId;
use rustc::middle::ty;
use rustc::middle::{cstore, def, infer, ty, traits};
use rustc::session::Session;
use rustc_front::hir::*;
use std::borrow::Cow;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::str::FromStr;
use syntax::ast::Lit_::*;
use syntax::ast;
use syntax::codemap::{ExpnInfo, Span, ExpnFormat};
use syntax::errors::DiagnosticBuilder;
use syntax::ptr::P;
use consts::constant;
use rustc::session::Session;
use std::str::FromStr;
use std::ops::{Deref, DerefMut};
pub type MethodArgs = HirVec<P<Expr>>;
@ -23,6 +23,7 @@ pub const BEGIN_UNWIND: [&'static str; 3] = ["std", "rt", "begin_unwind"];
pub const BTREEMAP_PATH: [&'static str; 4] = ["collections", "btree", "map", "BTreeMap"];
pub const CLONE_PATH: [&'static str; 2] = ["Clone", "clone"];
pub const COW_PATH: [&'static str; 3] = ["collections", "borrow", "Cow"];
pub const DEFAULT_TRAIT_PATH: [&'static str; 3] = ["core", "default", "Default"];
pub const HASHMAP_PATH: [&'static str; 5] = ["std", "collections", "hash", "map", "HashMap"];
pub const LL_PATH: [&'static str; 3] = ["collections", "linked_list", "LinkedList"];
pub const MUTEX_PATH: [&'static str; 4] = ["std", "sync", "mutex", "Mutex"];
@ -132,7 +133,7 @@ pub fn match_type(cx: &LateContext, ty: ty::Ty, path: &[&str]) -> bool {
}
}
/// Check if the method call given in `expr` belongs to given trait.
/// Check if the method call given in `expr` belongs to given type.
pub fn match_impl_method(cx: &LateContext, expr: &Expr, path: &[&str]) -> bool {
let method_call = ty::MethodCall::expr(expr.id);
@ -186,6 +187,73 @@ pub fn match_path_ast(path: &ast::Path, segments: &[&str]) -> bool {
path.segments.iter().rev().zip(segments.iter().rev()).all(|(a, b)| a.identifier.name.as_str() == *b)
}
/// Get the definition associated to a path.
/// TODO: investigate if there is something more efficient for that.
pub fn path_to_def(cx: &LateContext, path: &[&str]) -> Option<cstore::DefLike> {
let cstore = &cx.tcx.sess.cstore;
let crates = cstore.crates();
let krate = crates.iter().find(|&&krate| cstore.crate_name(krate) == path[0]);
if let Some(krate) = krate {
let mut items = cstore.crate_top_level_items(*krate);
let mut path_it = path.iter().skip(1).peekable();
loop {
let segment = match path_it.next() {
Some(segment) => segment,
None => return None
};
for item in &mem::replace(&mut items, vec![]) {
if item.name.as_str() == *segment {
if path_it.peek().is_none() {
return Some(item.def);
}
let def_id = match item.def {
cstore::DefLike::DlDef(def) => def.def_id(),
cstore::DefLike::DlImpl(def_id) => def_id,
_ => panic!("Unexpected {:?}", item.def),
};
items = cstore.item_children(def_id);
break;
}
}
}
}
else {
None
}
}
/// Convenience function to get the `DefId` of a trait by path.
pub fn get_trait_def_id(cx: &LateContext, path: &[&str]) -> Option<DefId> {
let def = match path_to_def(cx, path) {
Some(def) => def,
None => return None,
};
match def {
cstore::DlDef(def::DefTrait(trait_id)) => Some(trait_id),
_ => None,
}
}
/// Check whether a type implements a trait.
/// See also `get_trait_def_id`.
pub fn implements_trait<'a, 'tcx>(cx: &LateContext<'a, 'tcx>, ty: ty::Ty<'tcx>, trait_id: DefId) -> bool {
cx.tcx.populate_implementations_for_trait_if_necessary(trait_id);
let infcx = infer::new_infer_ctxt(cx.tcx, &cx.tcx.tables, None, true);
let obligation = traits::predicate_for_trait_def(cx.tcx,
traits::ObligationCause::dummy(),
trait_id, 0, ty,
vec![]);
traits::SelectionContext::new(&infcx).evaluate_obligation_conservatively(&obligation)
}
/// Match an `Expr` against a chain of methods, and return the matched `Expr`s.
///
/// For example, if `expr` represents the `.baz()` in `foo.bar().baz()`,