Bump LLVM
This commit is contained in:
parent
a311f76736
commit
41966050a6
|
@ -291,7 +291,7 @@ public:
|
|||
auto constIdx = op.index().getDefiningOp<arith::ConstantOp>();
|
||||
if (!constIdx)
|
||||
return failure();
|
||||
auto constValue = constIdx.value().dyn_cast<IntegerAttr>();
|
||||
auto constValue = constIdx.getValue().dyn_cast<IntegerAttr>();
|
||||
if (!constValue || !constValue.getType().isa<IndexType>() ||
|
||||
constValue.getValue().getZExtValue() != 0)
|
||||
return failure();
|
||||
|
@ -672,7 +672,7 @@ struct SelectOfCast : public OpRewritePattern<SelectOp> {
|
|||
if (cst1.source().getType() != cst2.source().getType())
|
||||
return failure();
|
||||
|
||||
auto newSel = rewriter.create<SelectOp>(op.getLoc(), op.condition(),
|
||||
auto newSel = rewriter.create<SelectOp>(op.getLoc(), op.getCondition(),
|
||||
cst1.source(), cst2.source());
|
||||
|
||||
rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(), newSel);
|
||||
|
@ -697,9 +697,9 @@ struct SelectOfSubIndex : public OpRewritePattern<SelectOp> {
|
|||
if (cst1.source().getType() != cst2.source().getType())
|
||||
return failure();
|
||||
|
||||
auto newSel = rewriter.create<SelectOp>(op.getLoc(), op.condition(),
|
||||
auto newSel = rewriter.create<SelectOp>(op.getLoc(), op.getCondition(),
|
||||
cst1.source(), cst2.source());
|
||||
auto newIdx = rewriter.create<SelectOp>(op.getLoc(), op.condition(),
|
||||
auto newIdx = rewriter.create<SelectOp>(op.getLoc(), op.getCondition(),
|
||||
cst1.index(), cst2.index());
|
||||
rewriter.replaceOpWithNewOp<SubIndexOp>(op, op.getType(), newSel, newIdx);
|
||||
return success();
|
||||
|
@ -874,7 +874,8 @@ public:
|
|||
if (!src)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<Pointer2MemrefOp>(op, op.getType(), src.arg());
|
||||
rewriter.replaceOpWithNewOp<Pointer2MemrefOp>(op, op.getType(),
|
||||
src.getArg());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -714,17 +714,17 @@ bool handle(OpBuilder &b, CmpIOp cmpi, SmallVectorImpl<AffineExpr> &exprs,
|
|||
SmallVectorImpl<bool> &eqflags, SmallVectorImpl<Value> &applies) {
|
||||
AffineMap lhsmap =
|
||||
AffineMap::get(0, 1, getAffineSymbolExpr(0, cmpi.getContext()));
|
||||
if (!isValidIndex(cmpi.lhs())) {
|
||||
if (!isValidIndex(cmpi.getLhs())) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "illegal lhs: " << cmpi.lhs() << " - " << cmpi << "\n");
|
||||
<< "illegal lhs: " << cmpi.getLhs() << " - " << cmpi << "\n");
|
||||
return false;
|
||||
}
|
||||
if (!isValidIndex(cmpi.rhs())) {
|
||||
if (!isValidIndex(cmpi.getRhs())) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "illegal rhs: " << cmpi.rhs() << " - " << cmpi << "\n");
|
||||
<< "illegal rhs: " << cmpi.getRhs() << " - " << cmpi << "\n");
|
||||
return false;
|
||||
}
|
||||
SmallVector<Value, 4> lhspack = {cmpi.lhs()};
|
||||
SmallVector<Value, 4> lhspack = {cmpi.getLhs()};
|
||||
if (!lhspack[0].getType().isa<IndexType>()) {
|
||||
auto op = b.create<IndexCastOp>(cmpi.getLoc(), lhspack[0],
|
||||
IndexType::get(cmpi.getContext()));
|
||||
|
@ -733,7 +733,7 @@ bool handle(OpBuilder &b, CmpIOp cmpi, SmallVectorImpl<AffineExpr> &exprs,
|
|||
|
||||
AffineMap rhsmap =
|
||||
AffineMap::get(0, 1, getAffineSymbolExpr(0, cmpi.getContext()));
|
||||
SmallVector<Value, 4> rhspack = {cmpi.rhs()};
|
||||
SmallVector<Value, 4> rhspack = {cmpi.getRhs()};
|
||||
if (!rhspack[0].getType().isa<IndexType>()) {
|
||||
auto op = b.create<IndexCastOp>(cmpi.getLoc(), rhspack[0],
|
||||
IndexType::get(cmpi.getContext()));
|
||||
|
|
|
@ -247,7 +247,7 @@ replicateIntoRegion(Region ®ion, Value storage, ValueRange ivs,
|
|||
if (auto branch = dyn_cast<BranchOp>(&op)) {
|
||||
// if (!blocks.contains(branch.dest())) {
|
||||
if (isa_and_nonnull<polygeist::BarrierOp>(branch->getPrevNode())) {
|
||||
auto it = llvm::find(subgraphEntryPoints, branch.dest());
|
||||
auto it = llvm::find(subgraphEntryPoints, branch.getDest());
|
||||
assert(it != subgraphEntryPoints.end());
|
||||
emitStoreContinuationID(
|
||||
branch.getLoc(), std::distance(subgraphEntryPoints.begin(), it),
|
||||
|
|
|
@ -557,7 +557,7 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
|||
return failure();
|
||||
}
|
||||
|
||||
BlockArgument indVar = cmpIOp.lhs().dyn_cast<BlockArgument>();
|
||||
BlockArgument indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>();
|
||||
if (!indVar)
|
||||
return failure();
|
||||
if (indVar.getOwner() != &loop.before().front())
|
||||
|
@ -608,11 +608,11 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
|||
loop.getLoc(), loop.getOperand(indVar.getArgNumber()),
|
||||
rewriter.create<ConstantIntOp>(loop.getLoc(), 1, indVar.getType()));
|
||||
|
||||
if (isBlockArg(cmpIOp.rhs()) || dominateWhile(cmpIOp.rhs(), loop)) {
|
||||
if (isBlockArg(cmpIOp.getRhs()) || dominateWhile(cmpIOp.getRhs(), loop)) {
|
||||
switch (cmpIOp.getPredicate()) {
|
||||
case CmpIPredicate::slt:
|
||||
case CmpIPredicate::ult: {
|
||||
loopInfo.ub = cmpIOp.rhs();
|
||||
loopInfo.ub = cmpIOp.getRhs();
|
||||
break;
|
||||
}
|
||||
case CmpIPredicate::ule:
|
||||
|
@ -620,13 +620,14 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
|||
// TODO: f32 likely not always true.
|
||||
auto one =
|
||||
rewriter.create<ConstantIntOp>(loop.getLoc(), 1, indVar.getType());
|
||||
auto addIOp = rewriter.create<AddIOp>(loop.getLoc(), cmpIOp.rhs(), one);
|
||||
auto addIOp =
|
||||
rewriter.create<AddIOp>(loop.getLoc(), cmpIOp.getRhs(), one);
|
||||
loopInfo.ub = addIOp.getResult();
|
||||
break;
|
||||
}
|
||||
case CmpIPredicate::uge:
|
||||
case CmpIPredicate::sge: {
|
||||
loopInfo.lb = cmpIOp.rhs();
|
||||
loopInfo.lb = cmpIOp.getRhs();
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -635,7 +636,8 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
|||
// TODO: f32 likely not always true.
|
||||
auto one =
|
||||
rewriter.create<ConstantIntOp>(loop.getLoc(), 1, indVar.getType());
|
||||
auto addIOp = rewriter.create<AddIOp>(loop.getLoc(), cmpIOp.rhs(), one);
|
||||
auto addIOp =
|
||||
rewriter.create<AddIOp>(loop.getLoc(), cmpIOp.getRhs(), one);
|
||||
loopInfo.lb = addIOp.getResult();
|
||||
break;
|
||||
}
|
||||
|
@ -647,12 +649,12 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
|||
} else {
|
||||
if (negativeStep)
|
||||
return failure();
|
||||
auto *op = cmpIOp.rhs().getDefiningOp();
|
||||
auto *op = cmpIOp.getRhs().getDefiningOp();
|
||||
if (!op || !canMoveOpOutsideWhile(op, loop) || (op->getNumResults() != 1))
|
||||
return failure();
|
||||
auto newOp = rewriter.clone(*op);
|
||||
loopInfo.ub = newOp->getResult(0);
|
||||
cmpIOp.rhs().replaceAllUsesWith(newOp->getResult(0));
|
||||
cmpIOp.getRhs().replaceAllUsesWith(newOp->getResult(0));
|
||||
}
|
||||
|
||||
if ((!loopInfo.ub) || (!loopInfo.lb) || (!step))
|
||||
|
@ -683,7 +685,7 @@ struct MoveWhileToFor : public OpRewritePattern<WhileOp> {
|
|||
Type cst = nullptr;
|
||||
if (auto idx = arg.getDefiningOp<IndexCastOp>()) {
|
||||
cst = idx.getType();
|
||||
arg = idx.in();
|
||||
arg = idx.getIn();
|
||||
}
|
||||
Value res;
|
||||
if (isTopLevelArgValue(arg, &loop.before())) {
|
||||
|
@ -1048,13 +1050,12 @@ struct WhileLogicalNegation : public OpRewritePattern<WhileOp> {
|
|||
for (auto pair : llvm::zip(op.getResults(), term.args(), origAfterArgs)) {
|
||||
if (!std::get<0>(pair).use_empty()) {
|
||||
if (auto termCmp = std::get<1>(pair).getDefiningOp<CmpIOp>()) {
|
||||
if (termCmp.lhs() == condCmp.lhs() &&
|
||||
termCmp.rhs() == condCmp.rhs()) {
|
||||
if (termCmp.getLhs() == condCmp.getLhs() &&
|
||||
termCmp.getRhs() == condCmp.getRhs()) {
|
||||
// TODO generalize to logical negation of
|
||||
if (condCmp.getPredicate() == CmpIPredicate::slt &&
|
||||
termCmp.getPredicate() == CmpIPredicate::sge) {
|
||||
|
||||
auto i1Ty = rewriter.getIntegerType(1);
|
||||
rewriter.updateRootInPlace(op, [&] {
|
||||
rewriter.setInsertionPoint(op);
|
||||
auto truev =
|
||||
|
@ -1086,7 +1087,7 @@ struct WhileCmpOffset : public OpRewritePattern<WhileOp> {
|
|||
assert(origAfterArgs.size() == term.args().size());
|
||||
|
||||
if (auto condCmp = term.condition().getDefiningOp<CmpIOp>()) {
|
||||
if (auto addI = condCmp.lhs().getDefiningOp<AddIOp>()) {
|
||||
if (auto addI = condCmp.getLhs().getDefiningOp<AddIOp>()) {
|
||||
if (addI.getOperand(1).getDefiningOp() &&
|
||||
!op.before().isAncestor(
|
||||
addI.getOperand(1).getDefiningOp()->getParentRegion()))
|
||||
|
@ -1368,8 +1369,8 @@ struct MoveSideEffectFreeWhile : public OpRewritePattern<WhileOp> {
|
|||
auto rep =
|
||||
op.after().front().addArgument(IC->getOperand(0).getType());
|
||||
IC->moveBefore(&op.after().front(), op.after().front().begin());
|
||||
conds.push_back(IC.in());
|
||||
IC.inMutable().assign(rep);
|
||||
conds.push_back(IC.getIn());
|
||||
IC.getInMutable().assign(rep);
|
||||
op.after().front().getArgument(i).replaceAllUsesWith(
|
||||
IC->getResult(0));
|
||||
changed = true;
|
||||
|
|
|
@ -297,7 +297,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
}
|
||||
}
|
||||
if (auto storeOp = dyn_cast<LLVM::StoreOp>(user)) {
|
||||
if (storeOp.addr() == val) {
|
||||
if (storeOp.getAddr() == val) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Matching Store: " << storeOp << "\n");
|
||||
allStoreOps.insert(storeOp);
|
||||
continue;
|
||||
|
@ -319,14 +319,14 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
}
|
||||
}
|
||||
if (auto callOp = dyn_cast<mlir::CallOp>(user)) {
|
||||
if (callOp.callee() != "free") {
|
||||
if (callOp.getCallee() != "free") {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Aliasing Store: " << callOp << "\n");
|
||||
AliasingStoreOperations.insert(callOp);
|
||||
captured = true;
|
||||
}
|
||||
}
|
||||
if (auto callOp = dyn_cast<mlir::LLVM::CallOp>(user)) {
|
||||
if (*callOp.callee() != "free") {
|
||||
if (*callOp.getCallee() != "free") {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Aliasing Store: " << callOp << "\n");
|
||||
AliasingStoreOperations.insert(callOp);
|
||||
captured = true;
|
||||
|
@ -722,7 +722,7 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
}
|
||||
} else if (auto storeOp = dyn_cast<LLVM::StoreOp>(a)) {
|
||||
if (allStoreOps.count(storeOp)) {
|
||||
lastVal = storeOp.value();
|
||||
lastVal = storeOp.getValue();
|
||||
seenSubStore = false;
|
||||
}
|
||||
} else if (auto storeOp = dyn_cast<AffineStoreOp>(a)) {
|
||||
|
@ -975,15 +975,15 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
op.erase();
|
||||
} else if (auto op = dyn_cast<SwitchOp>(pred->getTerminator())) {
|
||||
mlir::OpBuilder builder(op.getOperation());
|
||||
SmallVector<Value> defaultOps(op.defaultOperands().begin(),
|
||||
op.defaultOperands().end());
|
||||
SmallVector<Value> defaultOps(op.getDefaultOperands().begin(),
|
||||
op.getDefaultOperands().end());
|
||||
|
||||
if (op.defaultDestination() == block)
|
||||
if (op.getDefaultDestination() == block)
|
||||
defaultOps.push_back(pval);
|
||||
|
||||
SmallVector<SmallVector<Value>> cases;
|
||||
SmallVector<ValueRange> vrange;
|
||||
for (auto pair : llvm::enumerate(op.caseDestinations())) {
|
||||
for (auto pair : llvm::enumerate(op.getCaseDestinations())) {
|
||||
cases.emplace_back(op.getCaseOperands(pair.index()).begin(),
|
||||
op.getCaseOperands(pair.index()).end());
|
||||
if (pair.value() == block) {
|
||||
|
@ -992,8 +992,8 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
vrange.push_back(cases.back());
|
||||
}
|
||||
builder.create<mlir::SwitchOp>(
|
||||
op.getLoc(), op.flag(), op.defaultDestination(), defaultOps,
|
||||
op.case_valuesAttr(), op.caseDestinations(), vrange);
|
||||
op.getLoc(), op.getFlag(), op.getDefaultDestination(), defaultOps,
|
||||
op.getCaseValuesAttr(), op.getCaseDestinations(), vrange);
|
||||
op.erase();
|
||||
} else {
|
||||
llvm_unreachable("unknown pred branch");
|
||||
|
@ -1064,12 +1064,12 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
}
|
||||
} else if (auto op = dyn_cast<SwitchOp>(pred->getTerminator())) {
|
||||
mlir::OpBuilder subbuilder(op.getOperation());
|
||||
if (op.defaultDestination() == block) {
|
||||
pval = op.defaultOperands()[blockArg.getArgNumber()];
|
||||
if (op.getDefaultDestination() == block) {
|
||||
pval = op.getDefaultOperands()[blockArg.getArgNumber()];
|
||||
if (pval == blockArg)
|
||||
pval = nullptr;
|
||||
}
|
||||
for (auto pair : llvm::enumerate(op.caseDestinations())) {
|
||||
for (auto pair : llvm::enumerate(op.getCaseDestinations())) {
|
||||
if (pair.value() == block) {
|
||||
auto pval2 =
|
||||
op.getCaseOperands(pair.index())[blockArg.getArgNumber()];
|
||||
|
@ -1210,14 +1210,14 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
op.erase();
|
||||
} else if (auto op = dyn_cast<SwitchOp>(pred->getTerminator())) {
|
||||
mlir::OpBuilder builder(op.getOperation());
|
||||
SmallVector<Value> defaultOps(op.defaultOperands().begin(),
|
||||
op.defaultOperands().end());
|
||||
if (op.defaultDestination() == block)
|
||||
SmallVector<Value> defaultOps(op.getDefaultOperands().begin(),
|
||||
op.getDefaultOperands().end());
|
||||
if (op.getDefaultDestination() == block)
|
||||
defaultOps.erase(defaultOps.begin() + blockArg.getArgNumber());
|
||||
|
||||
SmallVector<SmallVector<Value>> cases;
|
||||
SmallVector<ValueRange> vrange;
|
||||
for (auto pair : llvm::enumerate(op.caseDestinations())) {
|
||||
for (auto pair : llvm::enumerate(op.getCaseDestinations())) {
|
||||
cases.emplace_back(op.getCaseOperands(pair.index()).begin(),
|
||||
op.getCaseOperands(pair.index()).end());
|
||||
if (pair.value() == block) {
|
||||
|
@ -1226,9 +1226,10 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector<ssize_t> idx,
|
|||
}
|
||||
vrange.push_back(cases.back());
|
||||
}
|
||||
builder.create<mlir::SwitchOp>(
|
||||
op.getLoc(), op.flag(), op.defaultDestination(), defaultOps,
|
||||
op.case_valuesAttr(), op.caseDestinations(), vrange);
|
||||
builder.create<mlir::SwitchOp>(op.getLoc(), op.getFlag(),
|
||||
op.getDefaultDestination(),
|
||||
defaultOps, op.getCaseValuesAttr(),
|
||||
op.getCaseDestinations(), vrange);
|
||||
op.erase();
|
||||
}
|
||||
}
|
||||
|
@ -1295,7 +1296,7 @@ bool isPromotable(mlir::Value AI) {
|
|||
continue;
|
||||
} else if (isa<memref::DeallocOp>(U)) {
|
||||
continue;
|
||||
} else if (isa<CallOp>(U) && cast<CallOp>(U).callee() == "free") {
|
||||
} else if (isa<CallOp>(U) && cast<CallOp>(U).getCallee() == "free") {
|
||||
continue;
|
||||
} else if (isa<CallOp>(U)) {
|
||||
// TODO check "no capture", currently assume as a fallback always
|
||||
|
@ -1466,7 +1467,7 @@ void Mem2Reg::runOnFunction() {
|
|||
if (isa<LLVM::StoreOp, memref::StoreOp, AffineStoreOp,
|
||||
memref::DeallocOp>(U)) {
|
||||
toErase.push_back(U);
|
||||
} else if (isa<CallOp>(U) && cast<CallOp>(U).callee() == "free") {
|
||||
} else if (isa<CallOp>(U) && cast<CallOp>(U).getCallee() == "free") {
|
||||
toErase.push_back(U);
|
||||
} else if (auto CO = dyn_cast<memref::CastOp>(U)) {
|
||||
toErase.push_back(U);
|
||||
|
|
|
@ -171,7 +171,7 @@ void ParallelLower::runOnOperation() {
|
|||
symbolTable.getSymbolTable(getOperation());
|
||||
|
||||
getOperation()->walk([&](mlir::CallOp bidx) {
|
||||
if (bidx.callee() == "cudaThreadSynchronize")
|
||||
if (bidx.getCallee() == "cudaThreadSynchronize")
|
||||
bidx.erase();
|
||||
});
|
||||
|
||||
|
@ -300,7 +300,7 @@ void ParallelLower::runOnOperation() {
|
|||
bz.setInsertionPointToStart(blockB);
|
||||
auto newAlloca = bz.create<LLVM::AllocaOp>(
|
||||
alop.getLoc(), LLVM::LLVMPointerType::get(PT.getElementType(), 0),
|
||||
alop.arraySize());
|
||||
alop.getArraySize());
|
||||
alop.replaceAllUsesWith((mlir::Value)bz.create<LLVM::AddrSpaceCastOp>(
|
||||
alop.getLoc(), PT, newAlloca));
|
||||
alop.erase();
|
||||
|
@ -394,7 +394,7 @@ void ParallelLower::runOnOperation() {
|
|||
});
|
||||
|
||||
container.walk([&](LLVM::CallOp call) {
|
||||
if (call.callee().getValue() == "cudaMemcpy") {
|
||||
if (call.getCallee().getValue() == "cudaMemcpy") {
|
||||
OpBuilder bz(call);
|
||||
auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
|
||||
bz.create<LLVM::MemcpyOp>(call.getLoc(), call.getOperand(0),
|
||||
|
@ -404,7 +404,7 @@ void ParallelLower::runOnOperation() {
|
|||
bz.create<ConstantIntOp>(call.getLoc(), 0, call.getType(0)));
|
||||
call.erase();
|
||||
}
|
||||
if (call.callee().getValue() == "cudaMemset") {
|
||||
if (call.getCallee().getValue() == "cudaMemset") {
|
||||
OpBuilder bz(call);
|
||||
auto falsev = bz.create<ConstantIntOp>(call.getLoc(), false, 1);
|
||||
bz.create<LLVM::MemsetOp>(call.getLoc(), call.getOperand(0),
|
||||
|
@ -414,7 +414,7 @@ void ParallelLower::runOnOperation() {
|
|||
call.replaceAllUsesWith(ArrayRef<Value>(vals));
|
||||
call.erase();
|
||||
}
|
||||
if (call.callee().getValue() == "cudaMalloc") {
|
||||
if (call.getCallee().getValue() == "cudaMalloc") {
|
||||
|
||||
Value vals[] = {call.getOperand(0)};
|
||||
call.replaceAllUsesWith(ArrayRef<Value>(vals));
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 30d87d4a5d02f00ef58ebc24a0ee5c6c370b8b4c
|
||||
Subproject commit 2709fd1520bca98667db9c10b3156cac892949bc
|
|
@ -2145,6 +2145,17 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
|
|||
|
||||
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
|
||||
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
(sr->getDecl()->getName() == "memmove" ||
|
||||
sr->getDecl()->getName() == "__builtin_memmove")) {
|
||||
std::vector<mlir::Value> args = {
|
||||
getLLVM(expr->getArg(0)), getLLVM(expr->getArg(1)),
|
||||
getLLVM(expr->getArg(2)), /*isVolatile*/
|
||||
builder.create<ConstantIntOp>(loc, false, 1)};
|
||||
builder.create<LLVM::MemmoveOp>(loc, args[0], args[1], args[2],
|
||||
args[3]);
|
||||
return ValueCategory(args[0], /*isReference*/ false);
|
||||
}
|
||||
if (sr->getDecl()->getIdentifier() &&
|
||||
(sr->getDecl()->getName() == "cudaMemcpy" ||
|
||||
sr->getDecl()->getName() == "cudaMemcpyToSymbol" ||
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
#include "clang/AST/Expr.h"
|
||||
|
@ -28,10 +28,9 @@ using namespace mlir;
|
|||
using namespace llvm;
|
||||
using namespace clang;
|
||||
|
||||
Operation *mlirclang::buildLinalgOp(const AbstractOperation *op, OpBuilder &b,
|
||||
SmallVectorImpl<mlir::Value> &input,
|
||||
SmallVectorImpl<mlir::Value> &output) {
|
||||
StringRef name = op->name;
|
||||
Operation *buildLinalgOp(StringRef name, OpBuilder &b,
|
||||
SmallVectorImpl<mlir::Value> &input,
|
||||
SmallVectorImpl<mlir::Value> &output) {
|
||||
if (name.compare("linalg.copy") == 0) {
|
||||
assert(input.size() == 1 && "linalg::copyOp requires 1 input");
|
||||
assert(output.size() == 1 && "linalg::CopyOp requires 1 output");
|
||||
|
@ -50,13 +49,11 @@ mlirclang::replaceFuncByOperation(FuncOp f, StringRef opName, OpBuilder &b,
|
|||
assert(ctx->isOperationRegistered(opName) &&
|
||||
"Provided lower_to opName should be registered.");
|
||||
|
||||
const AbstractOperation *op = AbstractOperation::lookup(opName, ctx);
|
||||
|
||||
if (opName.startswith("linalg"))
|
||||
return buildLinalgOp(op, b, input, output);
|
||||
return buildLinalgOp(opName, b, input, output);
|
||||
|
||||
// NOTE: The attributes of the provided FuncOp is ignored.
|
||||
OperationState opState(b.getUnknownLoc(), op->name, input,
|
||||
OperationState opState(b.getUnknownLoc(), opName, input,
|
||||
f.getCallableResults(), {});
|
||||
return b.createOperation(opState);
|
||||
}
|
||||
|
|
|
@ -42,11 +42,6 @@ replaceFuncByOperation(mlir::FuncOp f, llvm::StringRef opName,
|
|||
mlir::OpBuilder &b,
|
||||
llvm::SmallVectorImpl<mlir::Value> &input,
|
||||
llvm::SmallVectorImpl<mlir::Value> &output);
|
||||
|
||||
mlir::Operation *buildLinalgOp(const mlir::AbstractOperation *op,
|
||||
mlir::OpBuilder &b,
|
||||
llvm::SmallVectorImpl<mlir::Value> &input,
|
||||
llvm::SmallVectorImpl<mlir::Value> &output);
|
||||
} // namespace mlirclang
|
||||
|
||||
#endif
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/Affine/Passes.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
#include "mlir/Dialect/SCF/Passes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
@ -388,6 +388,8 @@ int main(int argc, char **argv) {
|
|||
LLVM::LLVMPointerType::attachInterface<MemRefInsider>(context);
|
||||
LLVM::LLVMStructType::attachInterface<MemRefInsider>(context);
|
||||
MemRefType::attachInterface<PtrElementModel<MemRefType>>(context);
|
||||
LLVM::LLVMStructType::attachInterface<PtrElementModel<LLVM::LLVMStructType>>(
|
||||
context);
|
||||
|
||||
if (showDialects) {
|
||||
outs() << "Registered Dialects:\n";
|
||||
|
|
Loading…
Reference in New Issue