Fix backprop

This commit is contained in:
William S. Moses 2021-07-06 21:09:51 -04:00 committed by William Moses
parent 29997892ca
commit 10c886fe43
6 changed files with 219 additions and 26 deletions

View File

@ -13,6 +13,9 @@ include "Dialect.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
// HasParent<"ParallelOp">,
def BarrierOp
: Polygeist_Op<"barrier",
@ -37,4 +40,20 @@ def SubIndexOp : Polygeist_Op<"subindex", [
let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
// Memref2PointerOp
//===----------------------------------------------------------------------===//
// Op that takes any memref value (`source`) and produces an LLVM dialect
// pointer (`result`). It is side-effect free, exposes its operand through
// ViewLikeOpInterface, and registers canonicalization patterns but no folder.
// NOTE(review): the exact address the resulting pointer refers to (base vs.
// offset-adjusted) is decided by the lowering, which is not visible here.
def Memref2PointerOp : Polygeist_Op<"memref2pointer", [
  DeclareOpInterfaceMethods<ViewLikeOpInterface>, NoSideEffect
]> {
  let summary = "convert a memref to an LLVM pointer";
  let arguments = (ins AnyMemRef : $source);
  let results = (outs LLVM_AnyPointer : $result);
  let hasCanonicalizer = 1;
  let hasFolder = 0;
}
#endif // POLYGEIST_OPS

View File

@ -10,6 +10,9 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "polygeist/Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#define GET_OP_CLASSES
#include "polygeist/PolygeistOps.cpp.inc"
@ -262,4 +265,25 @@ void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results
.insert<SubIndexOpMemRefCastFolder, SubIndex2, SubToCast, DeallocSubView>(
context);
}
}
/// ViewLikeOpInterface hook: the value this op is a view of is its `source`
/// operand.
Value Memref2PointerOp::getViewSource() { return source(); }
/// Rewrite pattern that looks through a `memref::CastOp` feeding a
/// `Memref2PointerOp`: the op is replaced by a new `Memref2PointerOp` with the
/// same result type whose operand is the cast's own source. The pattern does
/// not match when the operand is not produced by a `memref::CastOp`.
class MemRef2PointerCast final : public OpRewritePattern<Memref2PointerOp> {
public:
  using OpRewritePattern<Memref2PointerOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(Memref2PointerOp op,
                                PatternRewriter &rewriter) const override {
    if (auto castOp = op.source().getDefiningOp<memref::CastOp>()) {
      // Bypass the cast and build the pointer directly from its input.
      rewriter.replaceOpWithNewOp<polygeist::Memref2PointerOp>(
          op, op.getType(), castOp.source());
      return success();
    }
    return failure();
  }
};
/// Registers the canonicalization patterns for Memref2PointerOp; currently
/// only MemRef2PointerCast is added to `results`.
void Memref2PointerOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<MemRef2PointerCast>(context);
}

View File

@ -1422,6 +1422,23 @@ struct SubToAdd : public OpRewritePattern<SubIOp> {
return failure();
}
};
/// Rewrite pattern on ReturnOp that erases every operation located after the
/// return inside the same block. The block is walked backwards from its last
/// operation until the return itself is reached, so trailing operations are
/// erased last-to-first. Succeeds only when at least one operation was erased.
struct ReturnSq : public OpRewritePattern<ReturnOp> {
  using OpRewritePattern<ReturnOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReturnOp op,
                                PatternRewriter &rewriter) const override {
    auto *block = op->getBlock();

    // Gather first and erase afterwards so the block is not mutated while it
    // is being iterated.
    SmallVector<Operation *> trailing;
    for (auto it = block->rbegin(); it != block->rend() && &*it != op; ++it)
      trailing.push_back(&*it);

    for (Operation *dead : trailing)
      rewriter.eraseOp(dead);

    return success(!trailing.empty());
  }
};
void CanonicalizeFor::runOnFunction() {
mlir::RewritePatternSet rpl(getFunction().getContext());
rpl.add<PropagateInLoopBody, DetectTrivialIndVarInArgs,
@ -1429,6 +1446,7 @@ void CanonicalizeFor::runOnFunction() {
MoveWhileDown, MoveWhileDown2, MoveWhileDown3,
MoveWhileInvariantIfResult, WhileLogicalNegation, SubToAdd,
WhileCmpOffset, WhileLICM, RemoveUnusedCondVar,
ReturnSq,
MoveSideEffectFreeWhile>(getFunction().getContext());
GreedyRewriteConfig config;
config.maxIterations = 47;
@ -1438,4 +1456,4 @@ void CanonicalizeFor::runOnFunction() {
/// Factory for the CanonicalizeFor pass: returns a newly constructed instance
/// owned by the caller.
std::unique_ptr<Pass> mlir::polygeist::createCanonicalizeForPass() {
return std::make_unique<CanonicalizeFor>();
}
}

View File

@ -1,6 +1,7 @@
#include "clang-mlir.h"
#include "llvm/Support/Debug.h"
#include <clang/AST/Decl.h>
#include <clang/Basic/DiagnosticOptions.h>
#include <clang/Basic/FileManager.h>
#include <clang/Basic/FileSystemOptions.h>
@ -91,7 +92,7 @@ ValueWithOffsets MLIRScanner::VisitDeclStmt(clang::DeclStmt *decl) {
for (auto sub : decl->decls()) {
if (auto vd = dyn_cast<VarDecl>(sub)) {
VisitVarDecl(vd);
} else if (isa<TypeAliasDecl, RecordDecl>(sub)) {
} else if (isa<TypeAliasDecl, RecordDecl, StaticAssertDecl, TypedefDecl>(sub)) {
} else {
llvm::errs() << " + visiting unknonwn sub decl stmt\n";
sub->dump();
@ -1232,7 +1233,18 @@ MLIRScanner::VisitArraySubscriptExpr(clang::ArraySubscriptExpr *expr) {
// Check the RHS has been successfully emitted
assert(rhs);
auto idx = castToIndex(getMLIRLocation(expr->getRBracketLoc()), rhs);
if (isa<clang::VectorType>(expr->getLHS()->getType()->getUnqualifiedDesugaredType())) {
assert(moo.isReference);
moo.isReference = false;
auto mt = moo.val.getType().cast<MemRefType>();
auto shape = std::vector<int64_t>(mt.getShape());
shape.erase(shape.begin());
auto mt0 = mlir::MemRefType::get(shape, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace());
moo.val = builder.create<polygeist::SubIndexOp>(loc, mt0, moo.val,
getConstantIndex(0));
}
return CommonArrayLookup(moo, idx);
}
@ -1287,7 +1299,6 @@ llvm::Type *anonymize(llvm::Type *T) {
const clang::FunctionDecl *MLIRScanner::EmitCallee(const Expr *E) {
E = E->IgnoreParens();
// Look through function-to-pointer decay.
if (auto ICE = dyn_cast<ImplicitCastExpr>(E)) {
if (ICE->getCastKind() == CK_FunctionToPointerDecay ||
@ -1300,6 +1311,7 @@ const clang::FunctionDecl *MLIRScanner::EmitCallee(const Expr *E) {
if (auto FD = dyn_cast<FunctionDecl>(DRE->getDecl())) {
return FD;
}
} else if (auto ME = dyn_cast<MemberExpr>(E)) {
if (auto FD = dyn_cast<FunctionDecl>(ME->getMemberDecl())) {
// TODO EmitIgnoredExpr(ME->getBase());
@ -1454,14 +1466,21 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
}
}
/*
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
if (sr->getDecl()->getIdentifier() &&
sr->getDecl()->getName() == "__shfl_up_sync") {
std::vector<mlir::Value> args;
for (auto a : expr->arguments()) {
args.push_back(Visit(a).getValue(builder));
}
builder.create<gpu::ShuffleOp>(loc, );
assert(0 && "__shfl_up_sync unhandled");
return nullptr;
}
}
*/
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
@ -1476,6 +1495,19 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
/*isReference*/ false);
}
}
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
if (sr->getDecl()->getIdentifier() &&
sr->getDecl()->getName() == "log") {
std::vector<mlir::Value> args;
for (auto a : expr->arguments()) {
args.push_back(Visit(a).getValue(builder));
}
return ValueWithOffsets(
builder.create<mlir::math::LogOp>(loc, args[0]),
/*isReference*/ false);
}
}
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
if (sr->getDecl()->getIdentifier() &&
@ -1607,7 +1639,11 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
assert(sub.isReference);
return sub.val;
}
return sub.getValue(builder);
auto val = sub.getValue(builder);
if (auto mt = val.getType().dyn_cast<MemRefType>()) {
val = builder.create<polygeist::Memref2PointerOp>(loc, LLVM::LLVMPointerType::get(mt.getElementType()), val);
}
return val;
};
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
@ -1743,22 +1779,28 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr()))
if (sr->getDecl()->getIdentifier() &&
(sr->getDecl()->getName() == "cudaMemcpy"))
if (auto BCdst = dyn_cast<clang::CastExpr>(expr->getArg(0)))
if (auto BCsrc = dyn_cast<clang::CastExpr>(expr->getArg(1))) {
auto dst = Visit(BCdst->getSubExpr()).getValue(builder);
auto src = Visit(BCsrc->getSubExpr()).getValue(builder);
auto elem =
(sr->getDecl()->getName() == "cudaMemcpy" || sr->getDecl()->getName() == "memcpy")) {
if (auto BCdst = dyn_cast<clang::CastExpr>(expr->getArg(0))) {
auto elem =
cast<clang::PointerType>(BCdst->getSubExpr()
->getType()
->getUnqualifiedDesugaredType())
->getPointeeType();
bool isArray = false;
Glob.getMLIRType(elem, &isArray);
if (auto BCsrc = dyn_cast<clang::CastExpr>(expr->getArg(1))) {
auto selem =
cast<clang::PointerType>(BCsrc->getSubExpr()
->getType()
->getUnqualifiedDesugaredType())
->getPointeeType();
if (elem == selem) {
auto dst = Visit(BCdst->getSubExpr()).getValue(builder);
auto src = Visit(BCsrc->getSubExpr()).getValue(builder);
bool dstArray = false;
Glob.getMLIRType(elem, &dstArray);
auto elemSize = getTypeSize(elem);
mlir::Value size = builder.create<mlir::IndexCastOp>(
loc, Visit(expr->getArg(2)).getValue(builder),
loc, Visit(expr->getArg( sr->getDecl()->getName() == "cudaMemcpy" ? 2 : 1)).getValue(builder),
mlir::IndexType::get(builder.getContext()));
size = builder.create<mlir::UnsignedDivIOp>(
loc, size,
@ -1777,7 +1819,7 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
builder.setInsertionPointToStart(&affineOp.getLoopBody().front());
if (isArray) {
if (dstArray) {
std::vector<mlir::Value> start = {getConstantIndex(0)};
auto mt =
Glob.getMLIRType(Glob.CGM.getContext().getPointerType(elem))
@ -1799,7 +1841,8 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
builder.setInsertionPoint(oldblock, oldpoint);
return ValueWithOffsets(getConstantIndex(0), /*isReference*/ false);
}
} } }
}
auto callee = EmitCallee(expr->getCallee());
@ -1822,6 +1865,7 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
}
std::set<std::string> funcs = {"strcmp",
"memcpy",
"open",
"fopen",
"memset",
@ -2225,6 +2269,20 @@ ValueWithOffsets MLIRScanner::VisitUnaryOperator(clang::UnaryOperator *U) {
return ValueWithOffsets(builder.create<mlir::XOrOp>(loc, val, c1),
/*isReference*/ false);
}
case clang::UnaryOperator::Opcode::UO_Not: {
assert(sub.val);
mlir::Value val = sub.getValue(builder);
if (!val.getType().isa<mlir::IntegerType>()) {
U->dump();
val.dump();
}
auto ty = val.getType().cast<mlir::IntegerType>();
auto c1 = builder.create<mlir::ConstantOp>(loc, ty,
builder.getIntegerAttr(ty, APInt::getAllOnesValue(ty.getWidth())));
return ValueWithOffsets(builder.create<mlir::XOrOp>(loc, val, c1),
/*isReference*/ false);
}
case clang::UnaryOperator::Opcode::UO_Deref: {
auto dref = sub.dereference(builder);
return dref;
@ -2619,6 +2677,12 @@ ValueWithOffsets MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
rhs.getValue(builder)),
/*isReference*/ false);
}
case clang::BinaryOperator::Opcode::BO_Xor: {
return ValueWithOffsets(builder.create<mlir::XOrOp>(loc,
lhs.getValue(builder),
rhs.getValue(builder)),
/*isReference*/ false);
}
case clang::BinaryOperator::Opcode::BO_Or: {
// TODO short circuit
return ValueWithOffsets(builder.create<mlir::OrOp>(loc,
@ -2865,8 +2929,17 @@ ValueWithOffsets MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
auto prev = lhs.getValue(builder);
mlir::Value result;
if (prev.getType().isa<mlir::FloatType>()) {
result = builder.create<mlir::AddFOp>(loc, prev, rhs.getValue(builder));
if (auto postTy = prev.getType().dyn_cast<mlir::FloatType>()) {
mlir::Value rhsV = rhs.getValue(builder);
auto prevTy = rhsV.getType().cast<mlir::FloatType>();
if (prevTy == postTy) {}
else if (prevTy.getWidth() < postTy.getWidth()) {
rhsV = builder.create<mlir::FPExtOp>(loc, rhsV, postTy);
} else {
rhsV = builder.create<mlir::FPTruncOp>(loc, rhsV, postTy);
}
assert(rhsV.getType() == prev.getType());
result = builder.create<mlir::AddFOp>(loc, prev, rhsV);
} else if (auto pt =
prev.getType().dyn_cast<mlir::LLVM::LLVMPointerType>()) {
result = builder.create<LLVM::GEPOp>(
@ -3406,7 +3479,7 @@ ValueWithOffsets MLIRScanner::VisitCastExpr(CastExpr *E) {
auto mlirType = getMLIRType(E->getType());
return ValueWithOffsets(
builder.create<mlir::NVVM::WarpSizeOp>(loc, mlirType),
/*isReference*/ true);
/*isReference*/ false);
}
}
}
@ -3776,6 +3849,13 @@ ValueWithOffsets MLIRScanner::VisitReturnStmt(clang::ReturnStmt *stmt) {
builder.create<mlir::ReturnOp>(loc);
} else if (stmt->getRetValue()) {
auto rv = Visit(stmt->getRetValue());
if (stmt->getRetValue()->getType()->isVoidType()) {
builder.create<mlir::ReturnOp>(loc);
return nullptr;
}
if (!rv.val) {
stmt->dump();
}
assert(rv.val);
if (stmt->getRetValue()->isLValue() || stmt->getRetValue()->isXValue()) {
assert(rv.isReference);
@ -4110,7 +4190,7 @@ void MLIRASTConsumer::HandleDeclContext(DeclContext *DC) {
continue;
if (fd->getIdentifier() == nullptr)
continue;
if (emitIfFound.count(fd->getName().str())) {
if ((emitIfFound.count("*") && fd->getName() != "fpclassify" && !fd->isStatic())|| emitIfFound.count(fd->getName().str())) {
functionsToEmit.push_back(fd);
} else {
}
@ -4135,7 +4215,7 @@ bool MLIRASTConsumer::HandleTopLevelDecl(DeclGroupRef dg) {
continue;
if (fd->getIdentifier() == nullptr)
continue;
if (emitIfFound.count(fd->getName().str())) {
if ((emitIfFound.count("*") && fd->getName() != "fpclassify" && !fd->isStatic())|| emitIfFound.count(fd->getName().str())) {
functionsToEmit.push_back(fd);
} else {
}
@ -4220,7 +4300,7 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
}
auto CXRD = dyn_cast<CXXRecordDecl>(RT->getDecl());
if (RT->getDecl()->isUnion() || (CXRD && CXRD->getNumBases() > 0) ||
if (RT->getDecl()->isUnion() || (CXRD && (!CXRD->hasDefinition() || CXRD->getDefinition()->getNumBases() > 0)) ||
ST->getNumElements() == 0 || recursive ||
(!ST->isLiteral() && (ST->getName().contains("SmallVector") ||
ST->getName() == "struct._IO_FILE" ||
@ -4301,6 +4381,27 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
return mlir::MemRefType::get({size}, ET);
}
if (auto AT = dyn_cast<clang::VectorType>(t)) {
bool subRef = false;
auto ET = getMLIRType(AT->getElementType(), &subRef, allowMerge);
int64_t size = AT->getNumElements();
if (subRef) {
auto mt = ET.cast<MemRefType>();
auto shape2 = std::vector<int64_t>(mt.getShape());
shape2.insert(shape2.begin(), size);
if (implicitRef)
*implicitRef = true;
return mlir::MemRefType::get(shape2, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace());
}
if (!allowMerge || ET.isa<LLVM::LLVMPointerType, LLVM::LLVMArrayType,
LLVM::LLVMFunctionType, LLVM::LLVMStructType>())
return LLVM::LLVMFixedVectorType::get(ET, size);
if (implicitRef)
*implicitRef = true;
return mlir::MemRefType::get({size}, ET);
}
if (auto PT = dyn_cast<clang::PointerType>(t)) {
auto PTT = PT->getPointeeType()->getUnqualifiedDesugaredType();
@ -4329,6 +4430,18 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
return LLVM::LLVMPointerType::get(subType);
}
if (isa<clang::VectorType>(PTT)) {
if (subType.isa<MemRefType>()) {
assert(subRef);
auto mt = subType.cast<MemRefType>();
auto shape2 = std::vector<int64_t>(mt.getShape());
shape2.insert(shape2.begin(), -1);
return mlir::MemRefType::get(shape2, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace());
} else
return LLVM::LLVMPointerType::get(subType);
}
if (isa<clang::RecordType>(PTT))
if (subRef) {
auto mt = subType.cast<MemRefType>();
@ -4605,6 +4718,16 @@ static bool parseMLIR(const char *Argv0, std::vector<std::string> filenames,
chars[a.length()] = 0;
Argv.push_back(chars);
}
if (ResourceDir != "") {
Argv.push_back("-resource-dir");
char *chars = (char *)malloc(ResourceDir.length() + 1);
memcpy(chars, ResourceDir.data(), ResourceDir.length());
chars[ResourceDir.length()] = 0;
Argv.push_back(chars);
}
if (Verbose) {
Argv.push_back("-v");
}
if (CUDAGPUArch != "") {
auto a = "--cuda-gpu-arch=" + CUDAGPUArch;
char *chars = (char *)malloc(a.length() + 1);
@ -4734,4 +4857,4 @@ static bool parseMLIR(const char *Argv0, std::vector<std::string> filenames,
triple = Clang->getTarget().getTriple();
}
return true;
}
}

View File

@ -107,6 +107,9 @@ struct ValueWithOffsets {
assert(val.getType().isa<mlir::MemRefType>());
// return ValueWithOffsets(builder.create<memref::SubIndexOp>(loc, mt0, val,
// c0), /*isReference*/true);
if (val.getType().cast<mlir::MemRefType>().getShape().size() != 1) {
llvm::errs() << " val: " << val << " ty: " << val.getType() << "\n";
}
assert(val.getType().cast<mlir::MemRefType>().getShape().size() == 1);
return builder.create<memref::LoadOp>(loc, val,
std::vector<mlir::Value>({c0}));
@ -762,4 +765,4 @@ public:
ValueWithOffsets CommonArrayToPointer(ValueWithOffsets val);
};
#endif
#endif

View File

@ -32,6 +32,9 @@ static cl::opt<bool> ImmediateMLIR("immediate", cl::init(false),
static cl::opt<bool> RaiseToAffine("raise-scf-to-affine", cl::init(false),
cl::desc("Raise SCF to Affine"));
// Command-line switch `-scal-rep` (on by default). When set, the affine
// scalar replacement pass is added to the optimization pipeline.
static cl::opt<bool>
    ScalarReplacement("scal-rep", cl::init(true),
                      cl::desc("Run affine scalar replacement"));
static cl::opt<bool>
DetectReduction("detect-reduction", cl::init(false),
cl::desc("Detect reduction in inner most loop"));
@ -61,6 +64,8 @@ static cl::opt<std::string> MArch("march", cl::init(""),
static cl::opt<std::string> ResourceDir("resource-dir", cl::init(""),
cl::desc("Resource-dir"));
static cl::opt<bool> Verbose("v", cl::init(false), cl::desc("Verbose"));
static cl::opt<bool>
showDialects("show-dialects",
llvm::cl::desc("Print the list of registered dialects"),
@ -151,6 +156,7 @@ int main(int argc, char **argv) {
optPM.addPass(polygeist::createLoopRestructurePass());
optPM.addPass(polygeist::replaceAffineCFGPass());
optPM.addPass(mlir::createCanonicalizerPass());
if (ScalarReplacement)
optPM.addPass(mlir::createAffineScalarReplacementPass());
optPM.addPass(mlir::createLoopInvariantCodeMotionPass());
optPM.addPass(mlir::createCanonicalizerPass());