Fix backprop
This commit is contained in:
parent
29997892ca
commit
10c886fe43
|
@ -13,6 +13,9 @@ include "Dialect.td"
|
|||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
|
||||
|
||||
// HasParent<"ParallelOp">,
|
||||
def BarrierOp
|
||||
: Polygeist_Op<"barrier",
|
||||
|
@ -37,4 +40,20 @@ def SubIndexOp : Polygeist_Op<"subindex", [
|
|||
let hasFolder = 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Memref2PointerOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Op that takes a memref operand and produces an LLVM dialect pointer result.
// It implements ViewLikeOpInterface (the memref operand is the viewed source)
// and is declared free of side effects. A canonicalizer is provided; no folder.
def Memref2PointerOp : Polygeist_Op<"memref2pointer", [
    DeclareOpInterfaceMethods<ViewLikeOpInterface>, NoSideEffect
]> {
  let summary = "convert a memref to an LLVM pointer";

  let arguments = (ins AnyMemRef : $source);
  let results = (outs LLVM_AnyPointer : $result);

  let hasCanonicalizer = 1;
  let hasFolder = 0;
}
|
||||
|
||||
#endif // POLYGEIST_OPS
|
||||
|
|
|
@ -10,6 +10,9 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "polygeist/Dialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "polygeist/PolygeistOps.cpp.inc"
|
||||
|
||||
|
@ -263,3 +266,24 @@ void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
.insert<SubIndexOpMemRefCastFolder, SubIndex2, SubToCast, DeallocSubView>(
|
||||
context);
|
||||
}
|
||||
|
||||
/// ViewLikeOpInterface hook: the value being viewed is the memref operand.
Value Memref2PointerOp::getViewSource() {
  return source();
}
|
||||
|
||||
/// Canonicalization pattern for memref2pointer: when the operand is produced
/// by a memref::CastOp, rebuild the op directly on the cast's source while
/// keeping the original result type.
class MemRef2PointerCast final : public OpRewritePattern<Memref2PointerOp> {
public:
  using OpRewritePattern<Memref2PointerOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(Memref2PointerOp op,
                                PatternRewriter &rewriter) const override {
    if (auto castOp = op.source().getDefiningOp<memref::CastOp>()) {
      rewriter.replaceOpWithNewOp<polygeist::Memref2PointerOp>(
          op, op.getType(), castOp.source());
      return success();
    }

    // The operand is not defined by a memref cast; nothing to rewrite.
    return failure();
  }
};
|
||||
/// Adds the canonicalization patterns of memref2pointer to `results`.
void Memref2PointerOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<MemRef2PointerCast>(context);
}
|
||||
|
|
|
@ -1422,6 +1422,23 @@ struct SubToAdd : public OpRewritePattern<SubIOp> {
|
|||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewrite pattern that erases every operation located after a ReturnOp in
/// the same block. Succeeds only when at least one operation was erased.
struct ReturnSq : public OpRewritePattern<ReturnOp> {
  using OpRewritePattern<ReturnOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReturnOp op,
                                PatternRewriter &rewriter) const override {
    Block *block = op->getBlock();

    // Walk the block backwards and collect operations until the return itself
    // is reached; collection is separated from erasure so the block is not
    // mutated while it is being iterated.
    SmallVector<Operation *> trailing;
    for (auto it = block->rbegin(), end = block->rend(); it != end; ++it) {
      Operation *candidate = &*it;
      if (candidate == op.getOperation())
        break;
      trailing.push_back(candidate);
    }

    // Erase in collection order, i.e. last operation of the block first.
    for (Operation *dead : trailing)
      rewriter.eraseOp(dead);

    return success(!trailing.empty());
  }
};
|
||||
void CanonicalizeFor::runOnFunction() {
|
||||
mlir::RewritePatternSet rpl(getFunction().getContext());
|
||||
rpl.add<PropagateInLoopBody, DetectTrivialIndVarInArgs,
|
||||
|
@ -1429,6 +1446,7 @@ void CanonicalizeFor::runOnFunction() {
|
|||
MoveWhileDown, MoveWhileDown2, MoveWhileDown3,
|
||||
MoveWhileInvariantIfResult, WhileLogicalNegation, SubToAdd,
|
||||
WhileCmpOffset, WhileLICM, RemoveUnusedCondVar,
|
||||
ReturnSq,
|
||||
MoveSideEffectFreeWhile>(getFunction().getContext());
|
||||
GreedyRewriteConfig config;
|
||||
config.maxIterations = 47;
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "clang-mlir.h"
|
||||
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <clang/AST/Decl.h>
|
||||
#include <clang/Basic/DiagnosticOptions.h>
|
||||
#include <clang/Basic/FileManager.h>
|
||||
#include <clang/Basic/FileSystemOptions.h>
|
||||
|
@ -91,7 +92,7 @@ ValueWithOffsets MLIRScanner::VisitDeclStmt(clang::DeclStmt *decl) {
|
|||
for (auto sub : decl->decls()) {
|
||||
if (auto vd = dyn_cast<VarDecl>(sub)) {
|
||||
VisitVarDecl(vd);
|
||||
} else if (isa<TypeAliasDecl, RecordDecl>(sub)) {
|
||||
} else if (isa<TypeAliasDecl, RecordDecl, StaticAssertDecl, TypedefDecl>(sub)) {
|
||||
} else {
|
||||
llvm::errs() << " + visiting unknonwn sub decl stmt\n";
|
||||
sub->dump();
|
||||
|
@ -1232,7 +1233,18 @@ MLIRScanner::VisitArraySubscriptExpr(clang::ArraySubscriptExpr *expr) {
|
|||
// Check the RHS has been successfully emitted
|
||||
assert(rhs);
|
||||
auto idx = castToIndex(getMLIRLocation(expr->getRBracketLoc()), rhs);
|
||||
if (isa<clang::VectorType>(expr->getLHS()->getType()->getUnqualifiedDesugaredType())) {
|
||||
assert(moo.isReference);
|
||||
moo.isReference = false;
|
||||
auto mt = moo.val.getType().cast<MemRefType>();
|
||||
|
||||
auto shape = std::vector<int64_t>(mt.getShape());
|
||||
shape.erase(shape.begin());
|
||||
auto mt0 = mlir::MemRefType::get(shape, mt.getElementType(),
|
||||
mt.getAffineMaps(), mt.getMemorySpace());
|
||||
moo.val = builder.create<polygeist::SubIndexOp>(loc, mt0, moo.val,
|
||||
getConstantIndex(0));
|
||||
}
|
||||
return CommonArrayLookup(moo, idx);
|
||||
}
|
||||
|
||||
|
@ -1287,7 +1299,6 @@ llvm::Type *anonymize(llvm::Type *T) {
|
|||
|
||||
const clang::FunctionDecl *MLIRScanner::EmitCallee(const Expr *E) {
|
||||
E = E->IgnoreParens();
|
||||
|
||||
// Look through function-to-pointer decay.
|
||||
if (auto ICE = dyn_cast<ImplicitCastExpr>(E)) {
|
||||
if (ICE->getCastKind() == CK_FunctionToPointerDecay ||
|
||||
|
@ -1300,6 +1311,7 @@ const clang::FunctionDecl *MLIRScanner::EmitCallee(const Expr *E) {
|
|||
if (auto FD = dyn_cast<FunctionDecl>(DRE->getDecl())) {
|
||||
return FD;
|
||||
}
|
||||
|
||||
} else if (auto ME = dyn_cast<MemberExpr>(E)) {
|
||||
if (auto FD = dyn_cast<FunctionDecl>(ME->getMemberDecl())) {
|
||||
// TODO EmitIgnoredExpr(ME->getBase());
|
||||
|
@ -1454,14 +1466,21 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
}
|
||||
}
|
||||
|
||||
/*
|
||||
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
|
||||
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
sr->getDecl()->getName() == "__shfl_up_sync") {
|
||||
std::vector<mlir::Value> args;
|
||||
for (auto a : expr->arguments()) {
|
||||
args.push_back(Visit(a).getValue(builder));
|
||||
}
|
||||
builder.create<gpu::ShuffleOp>(loc, );
|
||||
assert(0 && "__shfl_up_sync unhandled");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
|
||||
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
|
||||
|
@ -1476,6 +1495,19 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
/*isReference*/ false);
|
||||
}
|
||||
}
|
||||
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
|
||||
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
sr->getDecl()->getName() == "log") {
|
||||
std::vector<mlir::Value> args;
|
||||
for (auto a : expr->arguments()) {
|
||||
args.push_back(Visit(a).getValue(builder));
|
||||
}
|
||||
return ValueWithOffsets(
|
||||
builder.create<mlir::math::LogOp>(loc, args[0]),
|
||||
/*isReference*/ false);
|
||||
}
|
||||
}
|
||||
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
|
||||
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
|
@ -1607,7 +1639,11 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
assert(sub.isReference);
|
||||
return sub.val;
|
||||
}
|
||||
return sub.getValue(builder);
|
||||
auto val = sub.getValue(builder);
|
||||
if (auto mt = val.getType().dyn_cast<MemRefType>()) {
|
||||
val = builder.create<polygeist::Memref2PointerOp>(loc, LLVM::LLVMPointerType::get(mt.getElementType()), val);
|
||||
}
|
||||
return val;
|
||||
};
|
||||
|
||||
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
|
||||
|
@ -1743,22 +1779,28 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
|
||||
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr()))
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
(sr->getDecl()->getName() == "cudaMemcpy"))
|
||||
if (auto BCdst = dyn_cast<clang::CastExpr>(expr->getArg(0)))
|
||||
if (auto BCsrc = dyn_cast<clang::CastExpr>(expr->getArg(1))) {
|
||||
auto dst = Visit(BCdst->getSubExpr()).getValue(builder);
|
||||
auto src = Visit(BCsrc->getSubExpr()).getValue(builder);
|
||||
|
||||
(sr->getDecl()->getName() == "cudaMemcpy" || sr->getDecl()->getName() == "memcpy")) {
|
||||
if (auto BCdst = dyn_cast<clang::CastExpr>(expr->getArg(0))) {
|
||||
auto elem =
|
||||
cast<clang::PointerType>(BCdst->getSubExpr()
|
||||
->getType()
|
||||
->getUnqualifiedDesugaredType())
|
||||
->getPointeeType();
|
||||
bool isArray = false;
|
||||
Glob.getMLIRType(elem, &isArray);
|
||||
if (auto BCsrc = dyn_cast<clang::CastExpr>(expr->getArg(1))) {
|
||||
auto selem =
|
||||
cast<clang::PointerType>(BCsrc->getSubExpr()
|
||||
->getType()
|
||||
->getUnqualifiedDesugaredType())
|
||||
->getPointeeType();
|
||||
if (elem == selem) {
|
||||
auto dst = Visit(BCdst->getSubExpr()).getValue(builder);
|
||||
auto src = Visit(BCsrc->getSubExpr()).getValue(builder);
|
||||
|
||||
bool dstArray = false;
|
||||
Glob.getMLIRType(elem, &dstArray);
|
||||
auto elemSize = getTypeSize(elem);
|
||||
mlir::Value size = builder.create<mlir::IndexCastOp>(
|
||||
loc, Visit(expr->getArg(2)).getValue(builder),
|
||||
loc, Visit(expr->getArg( sr->getDecl()->getName() == "cudaMemcpy" ? 2 : 1)).getValue(builder),
|
||||
mlir::IndexType::get(builder.getContext()));
|
||||
size = builder.create<mlir::UnsignedDivIOp>(
|
||||
loc, size,
|
||||
|
@ -1777,7 +1819,7 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
|
||||
builder.setInsertionPointToStart(&affineOp.getLoopBody().front());
|
||||
|
||||
if (isArray) {
|
||||
if (dstArray) {
|
||||
std::vector<mlir::Value> start = {getConstantIndex(0)};
|
||||
auto mt =
|
||||
Glob.getMLIRType(Glob.CGM.getContext().getPointerType(elem))
|
||||
|
@ -1799,6 +1841,7 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
builder.setInsertionPoint(oldblock, oldpoint);
|
||||
|
||||
return ValueWithOffsets(getConstantIndex(0), /*isReference*/ false);
|
||||
} } }
|
||||
}
|
||||
|
||||
auto callee = EmitCallee(expr->getCallee());
|
||||
|
@ -1822,6 +1865,7 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
}
|
||||
|
||||
std::set<std::string> funcs = {"strcmp",
|
||||
"memcpy",
|
||||
"open",
|
||||
"fopen",
|
||||
"memset",
|
||||
|
@ -2225,6 +2269,20 @@ ValueWithOffsets MLIRScanner::VisitUnaryOperator(clang::UnaryOperator *U) {
|
|||
return ValueWithOffsets(builder.create<mlir::XOrOp>(loc, val, c1),
|
||||
/*isReference*/ false);
|
||||
}
|
||||
case clang::UnaryOperator::Opcode::UO_Not: {
|
||||
assert(sub.val);
|
||||
mlir::Value val = sub.getValue(builder);
|
||||
|
||||
if (!val.getType().isa<mlir::IntegerType>()) {
|
||||
U->dump();
|
||||
val.dump();
|
||||
}
|
||||
auto ty = val.getType().cast<mlir::IntegerType>();
|
||||
auto c1 = builder.create<mlir::ConstantOp>(loc, ty,
|
||||
builder.getIntegerAttr(ty, APInt::getAllOnesValue(ty.getWidth())));
|
||||
return ValueWithOffsets(builder.create<mlir::XOrOp>(loc, val, c1),
|
||||
/*isReference*/ false);
|
||||
}
|
||||
case clang::UnaryOperator::Opcode::UO_Deref: {
|
||||
auto dref = sub.dereference(builder);
|
||||
return dref;
|
||||
|
@ -2619,6 +2677,12 @@ ValueWithOffsets MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
|
|||
rhs.getValue(builder)),
|
||||
/*isReference*/ false);
|
||||
}
|
||||
case clang::BinaryOperator::Opcode::BO_Xor: {
|
||||
return ValueWithOffsets(builder.create<mlir::XOrOp>(loc,
|
||||
lhs.getValue(builder),
|
||||
rhs.getValue(builder)),
|
||||
/*isReference*/ false);
|
||||
}
|
||||
case clang::BinaryOperator::Opcode::BO_Or: {
|
||||
// TODO short circuit
|
||||
return ValueWithOffsets(builder.create<mlir::OrOp>(loc,
|
||||
|
@ -2865,8 +2929,17 @@ ValueWithOffsets MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
|
|||
auto prev = lhs.getValue(builder);
|
||||
|
||||
mlir::Value result;
|
||||
if (prev.getType().isa<mlir::FloatType>()) {
|
||||
result = builder.create<mlir::AddFOp>(loc, prev, rhs.getValue(builder));
|
||||
if (auto postTy = prev.getType().dyn_cast<mlir::FloatType>()) {
|
||||
mlir::Value rhsV = rhs.getValue(builder);
|
||||
auto prevTy = rhsV.getType().cast<mlir::FloatType>();
|
||||
if (prevTy == postTy) {}
|
||||
else if (prevTy.getWidth() < postTy.getWidth()) {
|
||||
rhsV = builder.create<mlir::FPExtOp>(loc, rhsV, postTy);
|
||||
} else {
|
||||
rhsV = builder.create<mlir::FPTruncOp>(loc, rhsV, postTy);
|
||||
}
|
||||
assert(rhsV.getType() == prev.getType());
|
||||
result = builder.create<mlir::AddFOp>(loc, prev, rhsV);
|
||||
} else if (auto pt =
|
||||
prev.getType().dyn_cast<mlir::LLVM::LLVMPointerType>()) {
|
||||
result = builder.create<LLVM::GEPOp>(
|
||||
|
@ -3406,7 +3479,7 @@ ValueWithOffsets MLIRScanner::VisitCastExpr(CastExpr *E) {
|
|||
auto mlirType = getMLIRType(E->getType());
|
||||
return ValueWithOffsets(
|
||||
builder.create<mlir::NVVM::WarpSizeOp>(loc, mlirType),
|
||||
/*isReference*/ true);
|
||||
/*isReference*/ false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3776,6 +3849,13 @@ ValueWithOffsets MLIRScanner::VisitReturnStmt(clang::ReturnStmt *stmt) {
|
|||
builder.create<mlir::ReturnOp>(loc);
|
||||
} else if (stmt->getRetValue()) {
|
||||
auto rv = Visit(stmt->getRetValue());
|
||||
if (stmt->getRetValue()->getType()->isVoidType()) {
|
||||
builder.create<mlir::ReturnOp>(loc);
|
||||
return nullptr;
|
||||
}
|
||||
if (!rv.val) {
|
||||
stmt->dump();
|
||||
}
|
||||
assert(rv.val);
|
||||
if (stmt->getRetValue()->isLValue() || stmt->getRetValue()->isXValue()) {
|
||||
assert(rv.isReference);
|
||||
|
@ -4110,7 +4190,7 @@ void MLIRASTConsumer::HandleDeclContext(DeclContext *DC) {
|
|||
continue;
|
||||
if (fd->getIdentifier() == nullptr)
|
||||
continue;
|
||||
if (emitIfFound.count(fd->getName().str())) {
|
||||
if ((emitIfFound.count("*") && fd->getName() != "fpclassify" && !fd->isStatic())|| emitIfFound.count(fd->getName().str())) {
|
||||
functionsToEmit.push_back(fd);
|
||||
} else {
|
||||
}
|
||||
|
@ -4135,7 +4215,7 @@ bool MLIRASTConsumer::HandleTopLevelDecl(DeclGroupRef dg) {
|
|||
continue;
|
||||
if (fd->getIdentifier() == nullptr)
|
||||
continue;
|
||||
if (emitIfFound.count(fd->getName().str())) {
|
||||
if ((emitIfFound.count("*") && fd->getName() != "fpclassify" && !fd->isStatic())|| emitIfFound.count(fd->getName().str())) {
|
||||
functionsToEmit.push_back(fd);
|
||||
} else {
|
||||
}
|
||||
|
@ -4220,7 +4300,7 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
|
|||
}
|
||||
|
||||
auto CXRD = dyn_cast<CXXRecordDecl>(RT->getDecl());
|
||||
if (RT->getDecl()->isUnion() || (CXRD && CXRD->getNumBases() > 0) ||
|
||||
if (RT->getDecl()->isUnion() || (CXRD && (!CXRD->hasDefinition() || CXRD->getDefinition()->getNumBases() > 0)) ||
|
||||
ST->getNumElements() == 0 || recursive ||
|
||||
(!ST->isLiteral() && (ST->getName().contains("SmallVector") ||
|
||||
ST->getName() == "struct._IO_FILE" ||
|
||||
|
@ -4301,6 +4381,27 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
|
|||
return mlir::MemRefType::get({size}, ET);
|
||||
}
|
||||
|
||||
if (auto AT = dyn_cast<clang::VectorType>(t)) {
|
||||
bool subRef = false;
|
||||
auto ET = getMLIRType(AT->getElementType(), &subRef, allowMerge);
|
||||
int64_t size = AT->getNumElements();
|
||||
if (subRef) {
|
||||
auto mt = ET.cast<MemRefType>();
|
||||
auto shape2 = std::vector<int64_t>(mt.getShape());
|
||||
shape2.insert(shape2.begin(), size);
|
||||
if (implicitRef)
|
||||
*implicitRef = true;
|
||||
return mlir::MemRefType::get(shape2, mt.getElementType(),
|
||||
mt.getAffineMaps(), mt.getMemorySpace());
|
||||
}
|
||||
if (!allowMerge || ET.isa<LLVM::LLVMPointerType, LLVM::LLVMArrayType,
|
||||
LLVM::LLVMFunctionType, LLVM::LLVMStructType>())
|
||||
return LLVM::LLVMFixedVectorType::get(ET, size);
|
||||
if (implicitRef)
|
||||
*implicitRef = true;
|
||||
return mlir::MemRefType::get({size}, ET);
|
||||
}
|
||||
|
||||
if (auto PT = dyn_cast<clang::PointerType>(t)) {
|
||||
auto PTT = PT->getPointeeType()->getUnqualifiedDesugaredType();
|
||||
|
||||
|
@ -4329,6 +4430,18 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
|
|||
return LLVM::LLVMPointerType::get(subType);
|
||||
}
|
||||
|
||||
if (isa<clang::VectorType>(PTT)) {
|
||||
if (subType.isa<MemRefType>()) {
|
||||
assert(subRef);
|
||||
auto mt = subType.cast<MemRefType>();
|
||||
auto shape2 = std::vector<int64_t>(mt.getShape());
|
||||
shape2.insert(shape2.begin(), -1);
|
||||
return mlir::MemRefType::get(shape2, mt.getElementType(),
|
||||
mt.getAffineMaps(), mt.getMemorySpace());
|
||||
} else
|
||||
return LLVM::LLVMPointerType::get(subType);
|
||||
}
|
||||
|
||||
if (isa<clang::RecordType>(PTT))
|
||||
if (subRef) {
|
||||
auto mt = subType.cast<MemRefType>();
|
||||
|
@ -4605,6 +4718,16 @@ static bool parseMLIR(const char *Argv0, std::vector<std::string> filenames,
|
|||
chars[a.length()] = 0;
|
||||
Argv.push_back(chars);
|
||||
}
|
||||
if (ResourceDir != "") {
|
||||
Argv.push_back("-resource-dir");
|
||||
char *chars = (char *)malloc(ResourceDir.length() + 1);
|
||||
memcpy(chars, ResourceDir.data(), ResourceDir.length());
|
||||
chars[ResourceDir.length()] = 0;
|
||||
Argv.push_back(chars);
|
||||
}
|
||||
if (Verbose) {
|
||||
Argv.push_back("-v");
|
||||
}
|
||||
if (CUDAGPUArch != "") {
|
||||
auto a = "--cuda-gpu-arch=" + CUDAGPUArch;
|
||||
char *chars = (char *)malloc(a.length() + 1);
|
||||
|
|
|
@ -107,6 +107,9 @@ struct ValueWithOffsets {
|
|||
assert(val.getType().isa<mlir::MemRefType>());
|
||||
// return ValueWithOffsets(builder.create<memref::SubIndexOp>(loc, mt0, val,
|
||||
// c0), /*isReference*/true);
|
||||
if (val.getType().cast<mlir::MemRefType>().getShape().size() != 1) {
|
||||
llvm::errs() << " val: " << val << " ty: " << val.getType() << "\n";
|
||||
}
|
||||
assert(val.getType().cast<mlir::MemRefType>().getShape().size() == 1);
|
||||
return builder.create<memref::LoadOp>(loc, val,
|
||||
std::vector<mlir::Value>({c0}));
|
||||
|
|
|
@ -32,6 +32,9 @@ static cl::opt<bool> ImmediateMLIR("immediate", cl::init(false),
|
|||
static cl::opt<bool> RaiseToAffine("raise-scf-to-affine", cl::init(false),
|
||||
cl::desc("Raise SCF to Affine"));
|
||||
|
||||
// Command-line switch (default on) that controls whether the affine scalar
// replacement pass is added to the optimization pipeline.
static cl::opt<bool>
    ScalarReplacement("scal-rep", cl::init(true),
                      cl::desc("Run affine scalar replacement"));
|
||||
|
||||
static cl::opt<bool>
|
||||
DetectReduction("detect-reduction", cl::init(false),
|
||||
cl::desc("Detect reduction in inner most loop"));
|
||||
|
@ -61,6 +64,8 @@ static cl::opt<std::string> MArch("march", cl::init(""),
|
|||
static cl::opt<std::string> ResourceDir("resource-dir", cl::init(""),
|
||||
cl::desc("Resource-dir"));
|
||||
|
||||
static cl::opt<bool> Verbose("v", cl::init(false), cl::desc("Verbose"));
|
||||
|
||||
static cl::opt<bool>
|
||||
showDialects("show-dialects",
|
||||
llvm::cl::desc("Print the list of registered dialects"),
|
||||
|
@ -151,6 +156,7 @@ int main(int argc, char **argv) {
|
|||
optPM.addPass(polygeist::createLoopRestructurePass());
|
||||
optPM.addPass(polygeist::replaceAffineCFGPass());
|
||||
optPM.addPass(mlir::createCanonicalizerPass());
|
||||
if (ScalarReplacement)
|
||||
optPM.addPass(mlir::createAffineScalarReplacementPass());
|
||||
optPM.addPass(mlir::createLoopInvariantCodeMotionPass());
|
||||
optPM.addPass(mlir::createCanonicalizerPass());
|
||||
|
|
Loading…
Reference in New Issue