diff --git a/src/utils.rs b/src/utils.rs index 4b144452063..c41c0a8681b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,20 +1,20 @@ -use rustc::lint::*; -use rustc_front::hir::*; +use consts::constant; use reexport::*; -use syntax::codemap::{ExpnInfo, Span, ExpnFormat}; use rustc::front::map::Node::*; +use rustc::lint::*; use rustc::middle::def_id::DefId; -use rustc::middle::ty; +use rustc::middle::{cstore, def, infer, ty, traits}; +use rustc::session::Session; +use rustc_front::hir::*; use std::borrow::Cow; +use std::mem; +use std::ops::{Deref, DerefMut}; +use std::str::FromStr; use syntax::ast::Lit_::*; use syntax::ast; +use syntax::codemap::{ExpnInfo, Span, ExpnFormat}; use syntax::errors::DiagnosticBuilder; use syntax::ptr::P; -use consts::constant; - -use rustc::session::Session; -use std::str::FromStr; -use std::ops::{Deref, DerefMut}; pub type MethodArgs = HirVec>; @@ -23,6 +23,7 @@ pub const BEGIN_UNWIND: [&'static str; 3] = ["std", "rt", "begin_unwind"]; pub const BTREEMAP_PATH: [&'static str; 4] = ["collections", "btree", "map", "BTreeMap"]; pub const CLONE_PATH: [&'static str; 2] = ["Clone", "clone"]; pub const COW_PATH: [&'static str; 3] = ["collections", "borrow", "Cow"]; +pub const DEFAULT_TRAIT_PATH: [&'static str; 3] = ["core", "default", "Default"]; pub const HASHMAP_PATH: [&'static str; 5] = ["std", "collections", "hash", "map", "HashMap"]; pub const LL_PATH: [&'static str; 3] = ["collections", "linked_list", "LinkedList"]; pub const MUTEX_PATH: [&'static str; 4] = ["std", "sync", "mutex", "Mutex"]; @@ -132,7 +133,7 @@ pub fn match_type(cx: &LateContext, ty: ty::Ty, path: &[&str]) -> bool { } } -/// Check if the method call given in `expr` belongs to given trait. +/// Check if the method call given in `expr` belongs to given type. pub fn match_impl_method(cx: &LateContext, expr: &Expr, path: &[&str]) -> bool { let method_call = ty::MethodCall::expr(expr.id); @@ -186,6 +187,73 @@ pub fn match_path_ast(path: &ast::Path, segments: &[&str]) -> bool { path.segments.iter().rev().zip(segments.iter().rev()).all(|(a, b)| a.identifier.name.as_str() == *b) } +/// Get the definition associated to a path. +/// TODO: investigate if there is something more efficient for that. +pub fn path_to_def(cx: &LateContext, path: &[&str]) -> Option { + let cstore = &cx.tcx.sess.cstore; + + let crates = cstore.crates(); + let krate = crates.iter().find(|&&krate| cstore.crate_name(krate) == path[0]); + if let Some(krate) = krate { + let mut items = cstore.crate_top_level_items(*krate); + let mut path_it = path.iter().skip(1).peekable(); + + loop { + let segment = match path_it.next() { + Some(segment) => segment, + None => return None + }; + + for item in &mem::replace(&mut items, vec![]) { + if item.name.as_str() == *segment { + if path_it.peek().is_none() { + return Some(item.def); + } + + let def_id = match item.def { + cstore::DefLike::DlDef(def) => def.def_id(), + cstore::DefLike::DlImpl(def_id) => def_id, + _ => panic!("Unexpected {:?}", item.def), + }; + + items = cstore.item_children(def_id); + break; + } + } + } + } + else { + None + } +} + +/// Convenience function to get the `DefId` of a trait by path. +pub fn get_trait_def_id(cx: &LateContext, path: &[&str]) -> Option { + let def = match path_to_def(cx, path) { + Some(def) => def, + None => return None, + }; + + match def { + cstore::DlDef(def::DefTrait(trait_id)) => Some(trait_id), + _ => None, + } +} + +/// Check whether a type implements a trait. +/// See also `get_trait_def_id`. +pub fn implements_trait<'a, 'tcx>(cx: &LateContext<'a, 'tcx>, ty: ty::Ty<'tcx>, trait_id: DefId) -> bool { + cx.tcx.populate_implementations_for_trait_if_necessary(trait_id); + + let infcx = infer::new_infer_ctxt(cx.tcx, &cx.tcx.tables, None, true); + let obligation = traits::predicate_for_trait_def(cx.tcx, + traits::ObligationCause::dummy(), + trait_id, 0, ty, + vec![]); + + traits::SelectionContext::new(&infcx).evaluate_obligation_conservatively(&obligation) +} + /// Match an `Expr` against a chain of methods, and return the matched `Expr`s. /// /// For example, if `expr` represents the `.baz()` in `foo.bar().baz()`,