#include "polly/Support/SCEVValidator.h"

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/RegionInfo.h"

#include <algorithm>
#include <vector>

using namespace llvm;

namespace SCEVType {
  /// Classification of a SCEV (sub)expression.
  ///
  /// The enumerator order is significant: visitAddExpr and visitSMaxExpr
  /// combine operand types with std::max, so the enumerators must stay ordered
  /// INT < PARAM < IV < INVALID.
  enum TYPE {INT, PARAM, IV, INVALID};
}

/// The result of validating a SCEV (sub)expression: its classification and
/// the list of SCEVs that were recorded as parameters while visiting it.
struct ValidatorResult {
  SCEVType::TYPE type;
  std::vector<const SCEV*> Parameters;

  /// Create an INVALID result without parameters.
  ValidatorResult() : type(SCEVType::INVALID) {}

  /// Create a result of the given type without parameters.
  ValidatorResult(SCEVType::TYPE type) : type(type) {}

  /// Create a result of the given type with \p Expr as its only parameter.
  ValidatorResult(SCEVType::TYPE type, const SCEV* Expr) : type(type) {
    Parameters.push_back(Expr);
  }

  /// True if the result is of type INT or PARAM.
  bool isConstant() const {
    return type == SCEVType::INT || type == SCEVType::PARAM;
  }

  /// True unless the result is of type INVALID.
  bool isValid() const {
    return type != SCEVType::INVALID;
  }

  /// True if the result is of type IV.
  bool isIV() const {
    return type == SCEVType::IV;
  }

  /// True if the result is of type INT.
  bool isINT() const {
    return type == SCEVType::INT;
  }

  /// Append all parameters of \p Source to the parameters of this result.
  void addParamsFrom(const ValidatorResult &Source) {
    Parameters.insert(Parameters.end(), Source.Parameters.begin(),
                      Source.Parameters.end());
  }
};

/// Check if a SCEV is valid in a SCoP.
struct SCEVValidator
  : public SCEVVisitor<SCEVValidator, ValidatorResult> {
private:
  const Region *R;
  ScalarEvolution &SE;
  const Value *BaseAddress;

public:
  /// \param R           The region expressions are validated against.
  /// \param SE          Used to compute the step of add recurrences.
  /// \param BaseAddress A value that must not appear as a SCEVUnknown in a
  ///                    valid expression (see visitUnknown).
  SCEVValidator(const Region *R, ScalarEvolution &SE,
                const Value *BaseAddress)
    : R(R), SE(SE), BaseAddress(BaseAddress) {}

  /// A constant is always of type INT.
  ValidatorResult visitConstant(const SCEVConstant *Constant) {
    return ValidatorResult(SCEVType::INT);
  }

  ValidatorResult visitTruncateExpr(const SCEVTruncateExpr* Expr) {
    ValidatorResult Op = visit(Expr->getOperand());

    // We currently do not represent a truncate expression as an affine
    // expression. If it is constant during Scop execution, we treat it as a
    // parameter, otherwise we bail out.
    if (Op.isConstant())
      return ValidatorResult(SCEVType::PARAM, Expr);

    return ValidatorResult(SCEVType::INVALID);
  }

  ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr* Expr) {
    ValidatorResult Op = visit(Expr->getOperand());

    // We currently do not represent a zero extend expression as an affine
    // expression. If it is constant during Scop execution, we treat it as a
    // parameter, otherwise we bail out.
    if (Op.isConstant())
      return ValidatorResult(SCEVType::PARAM, Expr);

    return ValidatorResult(SCEVType::INVALID);
  }

  ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr* Expr) {
    // We currently allow only signed SCEV expressions. In the case of a
    // signed value, a sign extend is a noop.
    //
    // TODO: Reconsider this when we add support for unsigned values.
    return visit(Expr->getOperand());
  }

  /// A sum is valid if all operands are valid. Its type is the maximum of the
  /// operand types and its parameters are the union of the operand parameters.
  ValidatorResult visitAddExpr(const SCEVAddExpr* Expr) {
    ValidatorResult Return(SCEVType::INT);

    for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
      ValidatorResult Op = visit(Expr->getOperand(i));

      if (!Op.isValid())
        return ValidatorResult(SCEVType::INVALID);

      Return.type = std::max(Return.type, Op.type);
      Return.addParamsFrom(Op);
    }

    // TODO: Check for NSW and NUW.
    return Return;
  }

  /// A product is valid if at most one operand is not of type INT and that
  /// operand is itself valid. The product takes the type and the parameters of
  /// this operand.
  ValidatorResult visitMulExpr(const SCEVMulExpr* Expr) {
    ValidatorResult Return(SCEVType::INT);

    for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
      ValidatorResult Op = visit(Expr->getOperand(i));

      // INT factors neither change the type nor contribute parameters.
      if (Op.type == SCEVType::INT)
        continue;

      // Bail out on an invalid operand or on a second non-INT operand.
      if (Op.type == SCEVType::INVALID || Return.type != SCEVType::INT)
        return ValidatorResult(SCEVType::INVALID);

      Return.type = Op.type;
      Return.addParamsFrom(Op);
    }

    // TODO: Check for NSW and NUW.
    return Return;
  }

  ValidatorResult visitUDivExpr(const SCEVUDivExpr* Expr) {
    ValidatorResult LHS = visit(Expr->getLHS());
    ValidatorResult RHS = visit(Expr->getRHS());

    // We currently do not represent an unsigned division as an affine
    // expression. If the division is constant during Scop execution we treat it
    // as a parameter, otherwise we bail out.
    if (LHS.isConstant() && RHS.isConstant())
      return ValidatorResult(SCEVType::PARAM, Expr);

    return ValidatorResult(SCEVType::INVALID);
  }

  /// Validate an add recurrence.
  ///
  /// Only affine recurrences with a valid start and a valid, non-IV step are
  /// accepted. A recurrence over a loop that is not contained in the region is
  /// treated as a parameter (unless its start is of type IV). A recurrence over
  /// a loop inside the region must have a step of type INT and yields a result
  /// of type IV that carries the parameters of the start expression.
  ValidatorResult visitAddRecExpr(const SCEVAddRecExpr* Expr) {
    if (!Expr->isAffine())
      return ValidatorResult(SCEVType::INVALID);

    ValidatorResult Start = visit(Expr->getStart());
    ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));

    if (!Start.isValid() || !Recurrence.isValid() || Recurrence.isIV())
      return ValidatorResult(SCEVType::INVALID);

    if (!R->contains(Expr->getLoop())) {
      if (Start.isIV())
        return ValidatorResult(SCEVType::INVALID);

      return ValidatorResult(SCEVType::PARAM, Expr);
    }

    if (!Recurrence.isINT())
      return ValidatorResult(SCEVType::INVALID);

    ValidatorResult Result(SCEVType::IV);
    Result.addParamsFrom(Start);
    return Result;
  }

  /// A signed maximum is valid if all operands are valid. Its type is the
  /// maximum of the operand types and its parameters are the union of the
  /// operand parameters.
  ValidatorResult visitSMaxExpr(const SCEVSMaxExpr* Expr) {
    ValidatorResult Return(SCEVType::INT);

    for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
      ValidatorResult Op = visit(Expr->getOperand(i));

      if (!Op.isValid())
        return ValidatorResult(SCEVType::INVALID);

      Return.type = std::max(Return.type, Op.type);
      Return.addParamsFrom(Op);
    }

    return Return;
  }

  ValidatorResult visitUMaxExpr(const SCEVUMaxExpr* Expr) {
    // We do not support unsigned operations. If 'Expr' is constant during Scop
    // execution we treat this as a parameter, otherwise we bail out.
    for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
      ValidatorResult Op = visit(Expr->getOperand(i));

      if (!Op.isConstant())
        return ValidatorResult(SCEVType::INVALID);
    }

    // The whole expression is the parameter, not its operands. This matches
    // the handling of truncate, zero extend and unsigned division above.
    return ValidatorResult(SCEVType::PARAM, Expr);
  }

  /// Validate an opaque value.
  ///
  /// Undefined values, instructions contained in the region and the base
  /// address itself are rejected. Every other value is a parameter.
  ValidatorResult visitUnknown(const SCEVUnknown* Expr) {
    Value *V = Expr->getValue();

    if (isa<UndefValue>(V))
      return ValidatorResult(SCEVType::INVALID);

    if (Instruction *I = dyn_cast<Instruction>(V))
      if (R->contains(I))
        return ValidatorResult(SCEVType::INVALID);

    if (BaseAddress == V)
      return ValidatorResult(SCEVType::INVALID);

    return ValidatorResult(SCEVType::PARAM, Expr);
  }
};

namespace polly {
  /// Check whether \p Expr is accepted by the SCEVValidator for region \p R.
  ///
  /// \return false if \p Expr is a SCEVCouldNotCompute or if validation yields
  ///         an INVALID result, true otherwise.
  bool isAffineExpr(const Region *R, const SCEV *Expr, ScalarEvolution &SE,
                    const Value *BaseAddress) {
    if (isa<SCEVCouldNotCompute>(Expr))
      return false;

    SCEVValidator Validator(R, SE, BaseAddress);
    ValidatorResult Result = Validator.visit(Expr);

    return Result.isValid();
  }

  /// Collect the parameters the SCEVValidator records while visiting \p Expr.
  ///
  /// \return an empty vector if \p Expr is a SCEVCouldNotCompute, otherwise
  ///         the parameter list of the validation result.
  std::vector<const SCEV*> getParamsInAffineExpr(const Region *R,
                                                 const SCEV *Expr,
                                                 ScalarEvolution &SE,
                                                 const Value *BaseAddress) {
    if (isa<SCEVCouldNotCompute>(Expr))
      return std::vector<const SCEV*>();

    SCEVValidator Validator(R, SE, BaseAddress);
    ValidatorResult Result = Validator.visit(Expr);

    return Result.Parameters;
  }
}