Correct errors

William S. Moses 2021-12-25 01:32:32 -05:00 committed by William Moses
parent 41966050a6
commit e9bfe63c16
9 changed files with 134 additions and 68 deletions

View File

@@ -255,51 +255,23 @@ public:
if (!srcOp)
return failure();
auto preMemRefType = srcOp.source().getType().cast<MemRefType>();
auto srcMemRefType = op.source().getType().cast<MemRefType>();
auto resMemRefType = op.result().getType().cast<MemRefType>();
// Check if there are multiple users of the dynamically sized memory
if (!op.source().hasOneUse())
return failure();
// Check that the source op indeed is a dynamically indexed memory in the
// 0'th index.
if (srcMemRefType.getShape()[0] != -1)
return failure();
// Check that this is indeed a rank reducing operation
if (srcMemRefType.getShape().size() !=
(resMemRefType.getShape().size() + 1))
return failure();
// Check that there is not a downstream cast of subindex result. This is a
// bit dubious, but allowing cast canonicalizations - when possible - to
// convert subindexes will ultimately result in fewer memref.subview
// operations to be inferred.
for (auto user : op.getResult().getUsers()) {
if (isa<memref::CastOp>(user))
return failure();
}
for (auto it : llvm::zip(srcMemRefType.getShape().drop_front(),
resMemRefType.getShape())) {
if (std::get<0>(it) != std::get<1>(it))
return failure();
}
// Check that we're indexing into the 0'th index in the 2nd subindex op
auto constIdx = op.index().getDefiningOp<arith::ConstantOp>();
if (!constIdx)
return failure();
auto constValue = constIdx.getValue().dyn_cast<IntegerAttr>();
if (!constValue || !constValue.getType().isa<IndexType>() ||
constValue.getValue().getZExtValue() != 0)
// Check that the previous op is the same rank.
if (srcMemRefType.getShape().size() != preMemRefType.getShape().size())
return failure();
// Valid optimization target; perform the substitution.
rewriter.replaceOpWithNewOp<SubIndexOp>(op, op.result().getType(),
srcOp.source(), srcOp.index());
rewriter.eraseOp(srcOp);
rewriter.replaceOpWithNewOp<SubIndexOp>(
op, op.result().getType(), srcOp.source(),
rewriter.create<arith::AddIOp>(op.getLoc(), op.index(), srcOp.index()));
return success();
}
};
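Note: the rewritten body above amounts to folding subindex(subindex(x, i), j) into subindex(x, i + j). A minimal sketch of just that core rewrite, reusing the builder calls from the hunk and omitting the rank and legality checks (the helper name is hypothetical, for illustration only):

// Minimal sketch: fold subindex(subindex(x, i), j) into subindex(x, i + j),
// assuming the legality checks shown above have already passed.
LogicalResult foldNestedSubIndex(SubIndexOp op, PatternRewriter &rewriter) {
  auto srcOp = op.source().getDefiningOp<SubIndexOp>();
  if (!srcOp)
    return failure();
  // Combined offset into the original source memref.
  Value sum = rewriter.create<arith::AddIOp>(op.getLoc(), op.index(),
                                             srcOp.index());
  rewriter.replaceOpWithNewOp<SubIndexOp>(op, op.result().getType(),
                                          srcOp.source(), sum);
  return success();
}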
@@ -708,10 +680,10 @@ struct SelectOfSubIndex : public OpRewritePattern<SelectOp> {
void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results
.insert<CastOfSubIndex, SubIndexOpMemRefCastFolder, SubIndex2, SubToCast,
SimplifySubViewUsers, SimplifySubIndexUsers, SelectOfCast,
SelectOfSubIndex, SubToSubView, RedundantDynSubIndex>(context);
results.insert<CastOfSubIndex, SubIndexOpMemRefCastFolder, SubIndex2,
SubToCast, SimplifySubViewUsers, SimplifySubIndexUsers,
SelectOfCast, SelectOfSubIndex, RedundantDynSubIndex>(context);
// Disabled: SubToSubView
}
/// Simplify memref2pointer(cast(x)) to memref2pointer(x)

View File

@@ -1157,7 +1157,8 @@ struct MoveWhileDown3 : public OpRewritePattern<WhileOp> {
// TODO generalize to any non memory effecting op
if (auto idx =
std::get<1>(pair).getDefiningOp<MemoryEffectOpInterface>()) {
if (idx.hasNoEffect()) {
if (idx.hasNoEffect() &&
!llvm::is_contained(newOps, std::get<1>(pair))) {
Operation *cloned = std::get<1>(pair).getDefiningOp();
if (!std::get<1>(pair).hasOneUse()) {
cloned = std::get<1>(pair).getDefiningOp()->clone();

View File

@@ -505,8 +505,8 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
if (successor == target) {
OpBuilder builder(terminator);
auto vfalse = builder.create<mlir::ConstantOp>(
builder.getUnknownLoc(), i1Ty, builder.getIntegerAttr(i1Ty, 0));
auto vfalse = builder.create<arith::ConstantIntOp>(
builder.getUnknownLoc(), false, 1);
std::vector<Value> args = {vfalse};
for (auto arg : header->getArguments())
@@ -557,8 +557,8 @@ void LoopRestructure::runOnRegion(DominanceInfo &domInfo, Region &region) {
if (successor == header) {
OpBuilder builder(terminator);
auto vtrue = builder.create<mlir::ConstantOp>(
builder.getUnknownLoc(), i1Ty, builder.getIntegerAttr(i1Ty, 1));
auto vtrue = builder.create<arith::ConstantIntOp>(
builder.getUnknownLoc(), true, 1);
if (auto op = dyn_cast<BranchOp>(terminator)) {
std::vector<Value> args(op.getOperands().begin(),
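Note on the two LoopRestructure hunks above: the i1 loop-exit flags are now built with arith::ConstantIntOp, which takes a value and a bit width directly instead of an explicit IntegerAttr. A minimal sketch of the new form (builder and loc stand in for the surrounding code):

// arith.constant of width 1 yields an i1; false/true map to 0/1.
Value vfalse = builder.create<arith::ConstantIntOp>(loc, /*value=*/false, /*width=*/1);
Value vtrue = builder.create<arith::ConstantIntOp>(loc, /*value=*/true, /*width=*/1);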

View File

@@ -1,5 +1,5 @@
// RUN: polygeist-opt --canonicalize --split-input-file %s | FileCheck %s
// XFAIL: *
// CHECK: func @main(%arg0: index) -> memref<30xi32> {
// CHECK: %0 = memref.alloca() : memref<30x30xi32>
// CHECK: %1 = memref.subview %0[%arg0, 0] [1, 30] [1, 1] : memref<30x30xi32> to memref<30xi32>

View File

@@ -0,0 +1,46 @@
// RUN: polygeist-opt --canonicalize-scf-for --split-input-file %s | FileCheck %s
module {
func private @cmp() -> i1
func @_Z4div_Pi(%arg0: memref<?xi32>, %arg1: memref<?xi32>, %arg2: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c3_i64 = arith.constant 3 : index
%1:3 = scf.while (%arg3 = %c0_i32) : (i32) -> (i32, index, index) {
%2 = arith.index_cast %arg3 : i32 to index
%3 = arith.addi %2, %c3_i64 : index
%5 = call @cmp() : () -> i1
scf.condition(%5) %arg3, %3, %2 : i32, index, index
} do {
^bb0(%arg3: i32, %arg4: index, %arg5: index): // no predecessors
%parg3 = arith.addi %arg3, %c1_i32 : i32
%3 = memref.load %arg0[%arg5] : memref<?xi32>
memref.store %3, %arg1[%arg4] : memref<?xi32>
scf.yield %parg3 : i32
}
return
}
}
// CHECK: func @_Z4div_Pi(%arg0: memref<?xi32>, %arg1: memref<?xi32>, %arg2: i32) {
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
// CHECK-DAG: %c1_i32 = arith.constant 1 : i32
// CHECK-DAG: %c3 = arith.constant 3 : index
// CHECK-NEXT: %0 = scf.while (%arg3 = %c0_i32) : (i32) -> i32 {
// CHECK-NEXT: %1 = call @cmp() : () -> i1
// CHECK-NEXT: scf.condition(%1) %arg3 : i32
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%arg3: i32): // no predecessors
// CHECK-NEXT: %1 = arith.index_cast %arg3 : i32 to index
// CHECK-NEXT: %2 = arith.index_cast %arg3 : i32 to index
// CHECK-NEXT: %3 = arith.addi %1, %c3 : index
// CHECK-NEXT: %4 = arith.addi %arg3, %c1_i32 : i32
// CHECK-NEXT: %5 = memref.load %arg0[%2] : memref<?xi32>
// CHECK-NEXT: memref.store %5, %arg1[%3] : memref<?xi32>
// CHECK-NEXT: scf.yield %4 : i32
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }

View File

@@ -351,7 +351,7 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name,
if (!alloc) {
alloc = abuilder.create<mlir::LLVM::AllocaOp>(
varLoc, mlir::LLVM::LLVMPointerType::get(t, memspace),
builder.create<ConstantIntOp>(varLoc, 1, 64), 0);
abuilder.create<ConstantIntOp>(varLoc, 1, 64), 0);
if (t.isa<mlir::IntegerType>()) {
abuilder.create<LLVM::StoreOp>(
varLoc, abuilder.create<mlir::LLVM::UndefOp>(varLoc, t), alloc);
@@ -360,12 +360,12 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name,
// LLVM::LLVMPointerType::get(LLVM::LLVMArrayType::get(t, 1)), alloc);
}
} else {
mlir::Value idxs[] = {getConstantIndex(0)};
mr = mlir::MemRefType::get(1, t, {}, memspace);
alloc = abuilder.create<mlir::memref::AllocaOp>(varLoc, mr);
alloc = abuilder.create<mlir::memref::CastOp>(
varLoc, alloc, mlir::MemRefType::get(-1, t, {}, memspace));
if (t.isa<mlir::IntegerType>()) {
mlir::Value idxs[] = {abuilder.create<ConstantIndexOp>(loc, 0)};
abuilder.create<mlir::memref::StoreOp>(
varLoc, abuilder.create<mlir::LLVM::UndefOp>(varLoc, t), alloc,
idxs);
@@ -3509,17 +3509,25 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
}
case clang::BinaryOperator::Opcode::BO_NE: {
auto lhs_v = lhs.getValue(builder);
auto rhs_v = rhs.getValue(builder);
if (auto mt = lhs_v.getType().dyn_cast<mlir::MemRefType>()) {
lhs_v = builder.create<polygeist::Memref2PointerOp>(
loc, LLVM::LLVMPointerType::get(mt.getElementType()), lhs_v);
}
if (auto mt = rhs_v.getType().dyn_cast<mlir::MemRefType>()) {
rhs_v = builder.create<polygeist::Memref2PointerOp>(
loc, LLVM::LLVMPointerType::get(mt.getElementType()), rhs_v);
}
mlir::Value res;
if (lhs_v.getType().isa<mlir::FloatType>()) {
res = builder.create<arith::CmpFOp>(loc, CmpFPredicate::UNE, lhs_v,
rhs.getValue(builder));
res =
builder.create<arith::CmpFOp>(loc, CmpFPredicate::UNE, lhs_v, rhs_v);
} else if (auto pt =
lhs_v.getType().dyn_cast<mlir::LLVM::LLVMPointerType>()) {
res = builder.create<LLVM::ICmpOp>(loc, mlir::LLVM::ICmpPredicate::ne,
lhs_v, rhs.getValue(builder));
lhs_v, rhs_v);
} else {
res = builder.create<arith::CmpIOp>(loc, CmpIPredicate::ne, lhs_v,
rhs.getValue(builder));
res = builder.create<arith::CmpIOp>(loc, CmpIPredicate::ne, lhs_v, rhs_v);
}
return fixInteger(res);
}
@@ -3603,6 +3611,14 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
case clang::BinaryOperator::Opcode::BO_Sub: {
auto lhs_v = lhs.getValue(builder);
auto rhs_v = rhs.getValue(builder);
if (auto mt = lhs_v.getType().dyn_cast<mlir::MemRefType>()) {
lhs_v = builder.create<polygeist::Memref2PointerOp>(
loc, LLVM::LLVMPointerType::get(mt.getElementType()), lhs_v);
}
if (auto mt = rhs_v.getType().dyn_cast<mlir::MemRefType>()) {
rhs_v = builder.create<polygeist::Memref2PointerOp>(
loc, LLVM::LLVMPointerType::get(mt.getElementType()), rhs_v);
}
if (lhs_v.getType().isa<mlir::FloatType>()) {
assert(rhs_v.getType() == lhs_v.getType());
return ValueCategory(builder.create<SubFOp>(loc, lhs_v, rhs_v),
@@ -3624,15 +3640,6 @@ ValueCategory MLIRScanner::VisitBinaryOperator(clang::BinaryOperator *BO) {
builder.create<LLVM::PtrToIntOp>(
loc, getMLIRType(BO->getType()), rhs_v)),
/*isReference*/ false);
} else if (auto mt = lhs_v.getType().dyn_cast<mlir::MemRefType>()) {
llvm::errs() << " memref ptrtoint: " << mt << "\n";
return ValueCategory(
builder.create<SubIOp>(loc,
builder.create<LLVM::PtrToIntOp>(
loc, getMLIRType(BO->getType()), lhs_v),
builder.create<LLVM::PtrToIntOp>(
loc, getMLIRType(BO->getType()), rhs_v)),
/*isReference*/ false);
} else {
return ValueCategory(builder.create<SubIOp>(loc, lhs_v, rhs_v),
/*isReference*/ false);
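Note on the BO_NE and BO_Sub hunks above: memref-typed operands are now coerced to LLVM pointers with polygeist::Memref2PointerOp before the comparison or subtraction is emitted, which is what allows the dedicated memref ptrtoint branch to be deleted. A minimal sketch of that coercion as a standalone helper (the helper name is hypothetical; the calls mirror the diff):

// Coerce a memref-typed value to an LLVM pointer to its element type;
// leave any other value unchanged.
static mlir::Value toPointerIfMemRef(mlir::OpBuilder &builder,
                                     mlir::Location loc, mlir::Value v) {
  if (auto mt = v.getType().dyn_cast<mlir::MemRefType>())
    return builder.create<polygeist::Memref2PointerOp>(
        loc, mlir::LLVM::LLVMPointerType::get(mt.getElementType()), v);
  return v;
}

With such a helper, both lhs_v and rhs_v could be normalized once before choosing between CmpFOp, LLVM::ICmpOp, and CmpIOp.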

View File

@@ -8,10 +8,10 @@ float* zmem(int n) {
}
// CHECK: func @zmem(%arg0: i32) -> memref<?xf32> attributes {llvm.linkage = #llvm.linkage<external>} {
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %c4 = arith.constant 4 : index
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-DAG: %cst = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %c4 = arith.constant 4 : index
// CHECK-DAG: %c0 = arith.constant 0 : index
// CHECK-DAG: %c1 = arith.constant 1 : index
// CHECK-NEXT: %0 = arith.extui %arg0 : i32 to i64
// CHECK-NEXT: %1 = arith.index_cast %0 : i64 to index
// CHECK-NEXT: %2 = arith.muli %1, %c4 : index

View File

@@ -42,8 +42,8 @@ void lt_kernel_cuda(MTensorIterator& iter) {
}
// CHECK: func @lt_kernel_cuda(%arg0: !llvm.ptr<struct<(struct<(ptr<struct<(i8, i8)>>)>)>>) attributes {llvm.linkage = #llvm.linkage<external>} {
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %c1_i64 = arith.constant 1 : i64
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
// CHECK-DAG: %c1_i64 = arith.constant 1 : i64
// CHECK-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.struct<(ptr<struct<(struct<(ptr<struct<(i8, i8)>>)>)>>)> : (i64) -> !llvm.ptr<struct<(ptr<struct<(struct<(ptr<struct<(i8, i8)>>)>)>>)>>
// CHECK-NEXT: %1 = llvm.alloca %c1_i64 x !llvm.struct<(ptr<struct<(struct<(ptr<struct<(i8, i8)>>)>)>>)> : (i64) -> !llvm.ptr<struct<(ptr<struct<(struct<(ptr<struct<(i8, i8)>>)>)>>)>>
// CHECK-NEXT: %2 = call @_ZNK15MTensorIterator11input_dtypeEv(%arg0) : (!llvm.ptr<struct<(struct<(ptr<struct<(i8, i8)>>)>)>>) -> i8
@@ -65,8 +65,8 @@ void lt_kernel_cuda(MTensorIterator& iter) {
// CHECK-NEXT: return %2 : i8
// CHECK-NEXT: }
// CHECK-NEXT: func private @_ZZ14lt_kernel_cudaENK3$_0clEv(%arg0: !llvm.ptr<struct<(ptr<struct<(struct<(ptr<struct<(i8, i8)>>)>)>>)>>) attributes {llvm.linkage = #llvm.linkage<internal>} {
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %c1_i64 = arith.constant 1 : i64
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
// CHECK-DAG: %c1_i64 = arith.constant 1 : i64
// CHECK-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.struct<(i8)> : (i64) -> !llvm.ptr<struct<(i8)>>
// CHECK-NEXT: %1 = llvm.alloca %c1_i64 x !llvm.struct<(i8)> : (i64) -> !llvm.ptr<struct<(i8)>>
// CHECK-NEXT: %2 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr<struct<(ptr<struct<(struct<(ptr<struct<(i8, i8)>>)>)>>)>>, i32, i32) -> !llvm.ptr<ptr<struct<(struct<(ptr<struct<(i8, i8)>>)>)>>>

View File

@@ -0,0 +1,40 @@
// RUN: mlir-clang %s --function=* -S | FileCheck %s
int MAX_DIMS;
struct A {
int x;
double y;
};
void div_(int* sizes) {
A data[25];
for (int i=0; i < MAX_DIMS; ++i) {
data[i].x = sizes[i];
}
}
// CHECK: func @_Z4div_Pi(%arg0: memref<?xi32>) attributes {llvm.linkage = #llvm.linkage<external>} {
// CHECK-DAG: %c1_i32 = arith.constant 1 : i32
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
// CHECK-DAG: %c1_i64 = arith.constant 1 : i64
// CHECK-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.array<25 x struct<(i32, f64)>> : (i64) -> !llvm.ptr<array<25 x struct<(i32, f64)>>>
// CHECK-NEXT: %1 = memref.get_global @MAX_DIMS : memref<1xi32>
// CHECK-NEXT: %2 = scf.while (%arg1 = %c0_i32) : (i32) -> i32 {
// CHECK-NEXT: %3 = affine.load %1[0] : memref<1xi32>
// CHECK-NEXT: %4 = arith.cmpi ult, %arg1, %3 : i32
// CHECK-NEXT: scf.condition(%4) %arg1 : i32
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%arg1: i32): // no predecessors
// CHECK-NEXT: %3 = arith.index_cast %arg1 : i32 to index
// CHECK-NEXT: %4 = arith.index_cast %3 : index to i64
// CHECK-NEXT: %5 = llvm.getelementptr %0[%c0_i32, %c0_i32] : (!llvm.ptr<array<25 x struct<(i32, f64)>>>, i32, i32) -> !llvm.ptr<struct<(i32, f64)>>
// CHECK-NEXT: %6 = llvm.getelementptr %5[%4] : (!llvm.ptr<struct<(i32, f64)>>, i64) -> !llvm.ptr<struct<(i32, f64)>>
// CHECK-NEXT: %7 = llvm.getelementptr %6[%c0_i32, %c0_i32] : (!llvm.ptr<struct<(i32, f64)>>, i32, i32) -> !llvm.ptr<i32>
// CHECK-NEXT: %8 = memref.load %arg0[%3] : memref<?xi32>
// CHECK-NEXT: llvm.store %8, %7 : !llvm.ptr<i32>
// CHECK-NEXT: %9 = arith.addi %arg1, %c1_i32 : i32
// CHECK-NEXT: scf.yield %9 : i32
// CHECK-NEXT: }
// CHECK-NEXT: return
// CHECK-NEXT: }