Restore ref to pointer

This commit is contained in:
William S. Moses 2021-08-19 00:20:15 -04:00 committed by William Moses
parent d5f446ce54
commit 2f3a98b131
5 changed files with 101 additions and 46 deletions

View File

@ -56,6 +56,8 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name,
} else { } else {
mr = mlir::MemRefType::get(1, t, {}, memspace); mr = mlir::MemRefType::get(1, t, {}, memspace);
alloc = abuilder.create<mlir::memref::AllocaOp>(loc, mr); alloc = abuilder.create<mlir::memref::AllocaOp>(loc, mr);
alloc = abuilder.create<mlir::memref::CastOp>(
loc, alloc, mlir::MemRefType::get(-1, t, {}, memspace));
} }
} else { } else {
auto mt = t.cast<mlir::MemRefType>(); auto mt = t.cast<mlir::MemRefType>();
@ -75,10 +77,16 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name,
} }
if (!alloc) { if (!alloc) {
assert(shape[0] != -1); if (pshape == -1)
shape[0] = 1;
mr = mlir::MemRefType::get(shape, mt.getElementType(), mt.getAffineMaps(), mr = mlir::MemRefType::get(shape, mt.getElementType(), mt.getAffineMaps(),
memspace); memspace);
alloc = abuilder.create<mlir::memref::AllocaOp>(loc, mr); alloc = abuilder.create<mlir::memref::AllocaOp>(loc, mr);
shape[0] = pshape;
alloc = abuilder.create<mlir::memref::CastOp>(
loc, alloc,
mlir::MemRefType::get(shape, mt.getElementType(), mt.getAffineMaps(),
memspace));
} }
} }
assert(alloc); assert(alloc);
@ -1260,14 +1268,26 @@ MLIRScanner::VisitConstructCommon(clang::CXXConstructExpr *cons, VarDecl *name,
Glob.getMLIRType(Glob.CGM.getContext().getLValueReferenceType( Glob.getMLIRType(Glob.CGM.getContext().getLValueReferenceType(
a->getType())) a->getType()))
.cast<MemRefType>(); .cast<MemRefType>();
auto shape = std::vector<int64_t>(mt.getShape());
auto pshape = shape[0];
if (pshape == -1)
shape[0] = 1;
assert(shape.size() == 2);
OpBuilder abuilder(builder.getContext()); OpBuilder abuilder(builder.getContext());
abuilder.setInsertionPointToStart(allocationScope); abuilder.setInsertionPointToStart(allocationScope);
auto alloc = abuilder.create<mlir::memref::AllocaOp>(loc, mt); auto alloc = abuilder.create<mlir::memref::AllocaOp>(
loc,
mlir::MemRefType::get(shape, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace()));
ValueWithOffsets(alloc, /*isRef*/ true) ValueWithOffsets(alloc, /*isRef*/ true)
.store(builder, arg, /*isArray*/ isArray); .store(builder, arg, /*isArray*/ isArray);
val = alloc; shape[0] = pshape;
val = builder.create<mlir::memref::CastOp>(
loc, alloc,
mlir::MemRefType::get(shape, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace()));
} else } else
val = arg.getValue(builder); val = arg.getValue(builder);
} else { } else {
@ -1278,17 +1298,27 @@ MLIRScanner::VisitConstructCommon(clang::CXXConstructExpr *cons, VarDecl *name,
Glob.getMLIRType(Glob.CGM.getContext().getLValueReferenceType( Glob.getMLIRType(Glob.CGM.getContext().getLValueReferenceType(
a->getType())) a->getType()))
.cast<MemRefType>(); .cast<MemRefType>();
auto shape = std::vector<int64_t>(mt.getShape());
auto pshape = shape[0];
if (pshape == -1)
shape[0] = 1;
assert(shape.size() == 2);
OpBuilder abuilder(builder.getContext()); OpBuilder abuilder(builder.getContext());
abuilder.setInsertionPointToStart(allocationScope); abuilder.setInsertionPointToStart(allocationScope);
auto alloc = abuilder.create<mlir::memref::AllocaOp>(loc, mt); auto alloc = abuilder.create<mlir::memref::AllocaOp>(loc, mt);
ValueWithOffsets(alloc, /*isRef*/ true) ValueWithOffsets(alloc, /*isRef*/ true)
.store(builder, arg, /*isArray*/ isArray); .store(builder, arg, /*isArray*/ isArray);
toRestore.emplace_back(ValueWithOffsets(alloc, /*isRef*/ true), arg); toRestore.emplace_back(ValueWithOffsets(alloc, /*isRef*/ true), arg);
val = alloc; shape[0] = pshape;
val = builder.create<memref::CastOp>(
loc, alloc,
MemRefType::get(shape, mt.getElementType(), mt.getAffineMaps(),
mt.getMemorySpace()));
} else } else
val = arg.val; val = arg.val;
/*
if (!isArray) if (!isArray)
if (auto mt = val.getType().dyn_cast<MemRefType>()) { if (auto mt = val.getType().dyn_cast<MemRefType>()) {
auto shape = std::vector<int64_t>(mt.getShape()); auto shape = std::vector<int64_t>(mt.getShape());
@ -1298,6 +1328,7 @@ MLIRScanner::VisitConstructCommon(clang::CXXConstructExpr *cons, VarDecl *name,
mlir::MemRefType::get(shape, mt.getElementType(), mlir::MemRefType::get(shape, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace())); mt.getAffineMaps(), mt.getMemorySpace()));
} }
*/
} }
args.push_back(val); args.push_back(val);
} }
@ -2379,13 +2410,26 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
auto mt = Glob.getMLIRType(Glob.CGM.getContext().getLValueReferenceType( auto mt = Glob.getMLIRType(Glob.CGM.getContext().getLValueReferenceType(
a->getType())) a->getType()))
.cast<MemRefType>(); .cast<MemRefType>();
auto shape = std::vector<int64_t>(mt.getShape());
assert(shape.size() == 2);
auto pshape = shape[0];
if (pshape == -1)
shape[0] = 1;
OpBuilder abuilder(builder.getContext()); OpBuilder abuilder(builder.getContext());
abuilder.setInsertionPointToStart(allocationScope); abuilder.setInsertionPointToStart(allocationScope);
auto alloc = abuilder.create<mlir::memref::AllocaOp>(loc, mt); auto alloc = abuilder.create<mlir::memref::AllocaOp>(
loc,
mlir::MemRefType::get(shape, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace()));
ValueWithOffsets(alloc, /*isRef*/ true) ValueWithOffsets(alloc, /*isRef*/ true)
.store(builder, arg, /*isArray*/ isArray); .store(builder, arg, /*isArray*/ isArray);
val = alloc; shape[0] = pshape;
val = builder.create<mlir::memref::CastOp>(
loc, alloc,
mlir::MemRefType::get(shape, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace()));
} else } else
val = arg.getValue(builder); val = arg.getValue(builder);
} else { } else {
@ -2404,26 +2448,32 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) {
llvm::errs() << " arg.val: " << arg.val << "\n"; llvm::errs() << " arg.val: " << arg.val << "\n";
llvm::errs() << " mt: " << mt << "\n"; llvm::errs() << " mt: " << mt << "\n";
} }
auto pshape = shape[0];
if (shape.size() == 2)
if (pshape == -1)
shape[0] = 1;
OpBuilder abuilder(builder.getContext()); OpBuilder abuilder(builder.getContext());
abuilder.setInsertionPointToStart(allocationScope); abuilder.setInsertionPointToStart(allocationScope);
auto alloc = abuilder.create<mlir::memref::AllocaOp>(loc, mt); auto alloc = abuilder.create<mlir::memref::AllocaOp>(
loc,
mlir::MemRefType::get(shape, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace()));
ValueWithOffsets(alloc, /*isRef*/ true) ValueWithOffsets(alloc, /*isRef*/ true)
.store(builder, arg, /*isArray*/ isArray); .store(builder, arg, /*isArray*/ isArray);
toRestore.emplace_back(ValueWithOffsets(alloc, /*isRef*/ true), arg); toRestore.emplace_back(ValueWithOffsets(alloc, /*isRef*/ true), arg);
val = alloc; if (shape.size() == 2) {
shape[0] = pshape;
val = builder.create<memref::CastOp>(
loc, alloc,
MemRefType::get(shape, mt.getElementType(), mt.getAffineMaps(),
mt.getMemorySpace()));
} else {
val = alloc;
}
} else } else
val = arg.val; val = arg.val;
if (!isArray)
if (auto mt = val.getType().dyn_cast<MemRefType>()) {
auto shape = std::vector<int64_t>(mt.getShape());
shape[0] = 1;
val = builder.create<memref::CastOp>(
loc, val,
mlir::MemRefType::get(shape, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace()));
}
} }
assert(val); assert(val);
/* /*
@ -4538,7 +4588,7 @@ mlir::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction(const FunctionDecl *FD) {
getMLIRType(CC->getThisObjectType(), &isArray); getMLIRType(CC->getThisObjectType(), &isArray);
if (auto mt = t.dyn_cast<MemRefType>()) { if (auto mt = t.dyn_cast<MemRefType>()) {
auto shape = std::vector<int64_t>(mt.getShape()); auto shape = std::vector<int64_t>(mt.getShape());
shape[0] = 1; //shape[0] = 1;
t = mlir::MemRefType::get(shape, mt.getElementType(), t = mlir::MemRefType::get(shape, mt.getElementType(),
mt.getAffineMaps(), mt.getMemorySpace()); mt.getAffineMaps(), mt.getMemorySpace());
} }
@ -4934,7 +4984,7 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
} }
if (isa<clang::PointerType, clang::ReferenceType>(t)) { if (isa<clang::PointerType, clang::ReferenceType>(t)) {
int64_t outer = (isa<clang::PointerType>(t)) ? -1 : 1; int64_t outer = (isa<clang::PointerType>(t)) ? -1 : -1;
auto PTT = isa<clang::PointerType>(t) ? cast<clang::PointerType>(t) auto PTT = isa<clang::PointerType>(t) ? cast<clang::PointerType>(t)
->getPointeeType() ->getPointeeType()
->getUnqualifiedDesugaredType() ->getUnqualifiedDesugaredType()

View File

@ -25,15 +25,17 @@ int create() {
// CHECK-NEXT: %1 = memref.alloca() : memref<1x2xi32> // CHECK-NEXT: %1 = memref.alloca() : memref<1x2xi32>
// CHECK-NEXT: affine.store %c0_i32, %1[0, 0] : memref<1x2xi32> // CHECK-NEXT: affine.store %c0_i32, %1[0, 0] : memref<1x2xi32>
// CHECK-NEXT: affine.store %c1_i32, %1[0, 1] : memref<1x2xi32> // CHECK-NEXT: affine.store %c1_i32, %1[0, 1] : memref<1x2xi32>
// CHECK-NEXT: call @byval(%1, %c2_i32, %0) : (memref<1x2xi32>, i32, memref<1x2xi32>) -> () // CHECK-NEXT: %2 = memref.cast %1 : memref<1x2xi32> to memref<?x2xi32>
// CHECK-NEXT: %2 = affine.load %0[0, 0] : memref<1x2xi32> // CHECK-NEXT: %3 = memref.cast %0 : memref<1x2xi32> to memref<?x2xi32>
// CHECK-NEXT: return %2 : i32 // CHECK-NEXT: call @byval(%2, %c2_i32, %3) : (memref<?x2xi32>, i32, memref<?x2xi32>) -> ()
// CHECK-NEXT: %4 = affine.load %0[0, 0] : memref<1x2xi32>
// CHECK-NEXT: return %4 : i32
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK: builtin.func @byval(%arg0: memref<1x2xi32>, %arg1: i32, %arg2: memref<1x2xi32>) { // CHECK: builtin.func @byval(%arg0: memref<?x2xi32>, %arg1: i32, %arg2: memref<?x2xi32>) {
// CHECK-NEXT: affine.store %arg1, %arg0[0, 1] : memref<1x2xi32> // CHECK-NEXT: affine.store %arg1, %arg0[0, 1] : memref<?x2xi32>
// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<1x2xi32> // CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<?x2xi32>
// CHECK-NEXT: affine.store %0, %arg2[0, 0] : memref<1x2xi32> // CHECK-NEXT: affine.store %0, %arg2[0, 0] : memref<?x2xi32>
// CHECK-NEXT: %1 = affine.load %arg0[0, 1] : memref<1x2xi32> // CHECK-NEXT: %1 = affine.load %arg0[0, 1] : memref<?x2xi32>
// CHECK-NEXT: affine.store %1, %arg2[0, 1] : memref<1x2xi32> // CHECK-NEXT: affine.store %1, %arg2[0, 1] : memref<?x2xi32>
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }

View File

@ -25,14 +25,15 @@ int create() {
// CHECK-NEXT: affine.store %c0_i32, %1[0, 0] : memref<1x2xi32> // CHECK-NEXT: affine.store %c0_i32, %1[0, 0] : memref<1x2xi32>
// CHECK-NEXT: affine.store %c1_i32, %1[0, 1] : memref<1x2xi32> // CHECK-NEXT: affine.store %c1_i32, %1[0, 1] : memref<1x2xi32>
// CHECK-NEXT: %2 = memref.cast %1 : memref<1x2xi32> to memref<?x2xi32> // CHECK-NEXT: %2 = memref.cast %1 : memref<1x2xi32> to memref<?x2xi32>
// CHECK-NEXT: call @byval(%2, %c2_i32, %0) : (memref<?x2xi32>, i32, memref<1x2xi32>) -> () // CHECK-NEXT: %3 = memref.cast %0 : memref<1x2xi32> to memref<?x2xi32>
// CHECK-NEXT: %3 = affine.load %0[0, 0] : memref<1x2xi32> // CHECK-NEXT: call @byval(%2, %c2_i32, %3) : (memref<?x2xi32>, i32, memref<?x2xi32>) -> ()
// CHECK-NEXT: return %3 : i32 // CHECK-NEXT: %4 = affine.load %0[0, 0] : memref<1x2xi32>
// CHECK-NEXT: return %4 : i32
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK: builtin.func @byval(%arg0: memref<?x2xi32>, %arg1: i32, %arg2: memref<1x2xi32>) { // CHECK: builtin.func @byval(%arg0: memref<?x2xi32>, %arg1: i32, %arg2: memref<?x2xi32>) {
// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<?x2xi32> // CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<?x2xi32>
// CHECK-NEXT: affine.store %0, %arg2[0, 0] : memref<1x2xi32> // CHECK-NEXT: affine.store %0, %arg2[0, 0] : memref<?x2xi32>
// CHECK-NEXT: %1 = affine.load %arg0[0, 1] : memref<?x2xi32> // CHECK-NEXT: %1 = affine.load %arg0[0, 1] : memref<?x2xi32>
// CHECK-NEXT: affine.store %1, %arg2[0, 1] : memref<1x2xi32> // CHECK-NEXT: affine.store %1, %arg2[0, 1] : memref<?x2xi32>
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }

View File

@ -17,14 +17,15 @@ void kernel_deriche() {
// CHECK: builtin.func @kernel_deriche() { // CHECK: builtin.func @kernel_deriche() {
// CHECK-NEXT: %c32_i32 = constant 32 : i32 // CHECK-NEXT: %c32_i32 = constant 32 : i32
// CHECK-NEXT: %0 = memref.alloca() : memref<1xi32> // CHECK-NEXT: %0 = memref.alloca() : memref<1xi32>
// CHECK-NEXT: %1 = memref.cast %0 : memref<1xi32> to memref<?xi32>
// CHECK-NEXT: affine.store %c32_i32, %0[0] : memref<1xi32> // CHECK-NEXT: affine.store %c32_i32, %0[0] : memref<1xi32>
// CHECK-NEXT: call @sub(%0) : (memref<1xi32>) -> () // CHECK-NEXT: call @sub(%1) : (memref<?xi32>) -> ()
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK: builtin.func @sub(%arg0: memref<1xi32>) { // CHECK: builtin.func @sub(%arg0: memref<?xi32>) {
// CHECK-NEXT: %c1_i32 = constant 1 : i32 // CHECK-NEXT: %c1_i32 = constant 1 : i32
// CHECK-NEXT: %0 = affine.load %arg0[0] : memref<1xi32> // CHECK-NEXT: %0 = affine.load %arg0[0] : memref<?xi32>
// CHECK-NEXT: %1 = addi %0, %c1_i32 : i32 // CHECK-NEXT: %1 = addi %0, %c1_i32 : i32
// CHECK-NEXT: affine.store %1, %arg0[0] : memref<1xi32> // CHECK-NEXT: affine.store %1, %arg0[0] : memref<?xi32>
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }

View File

@ -21,18 +21,19 @@ void kernel_deriche() {
// CHECK: builtin.func @kernel_deriche() { // CHECK: builtin.func @kernel_deriche() {
// CHECK-NEXT: %c32_i32 = constant 32 : i32 // CHECK-NEXT: %c32_i32 = constant 32 : i32
// CHECK-NEXT: %0 = memref.alloca() : memref<1x2xi32> // CHECK-NEXT: %0 = memref.alloca() : memref<1x2xi32>
// CHECK-NEXT: call @_ZN4pairC1Ev(%0) : (memref<1x2xi32>) -> () // CHECK-NEXT: %1 = memref.cast %0 : memref<1x2xi32> to memref<?x2xi32>
// CHECK-NEXT: call @_ZN4pairC1Ev(%1) : (memref<?x2xi32>) -> ()
// CHECK-NEXT: affine.store %c32_i32, %0[0, 0] : memref<1x2xi32> // CHECK-NEXT: affine.store %c32_i32, %0[0, 0] : memref<1x2xi32>
// CHECK-NEXT: call @sub(%0) : (memref<1x2xi32>) -> () // CHECK-NEXT: call @sub(%1) : (memref<?x2xi32>) -> ()
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK: builtin.func @_ZN4pairC1Ev(%arg0: memref<1x2xi32>) { // CHECK: builtin.func @_ZN4pairC1Ev(%arg0: memref<?x2xi32>) {
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK: builtin.func @sub(%arg0: memref<1x2xi32>) { // CHECK: builtin.func @sub(%arg0: memref<?x2xi32>) {
// CHECK-NEXT: %c1_i32 = constant 1 : i32 // CHECK-NEXT: %c1_i32 = constant 1 : i32
// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<1x2xi32> // CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<?x2xi32>
// CHECK-NEXT: %1 = addi %0, %c1_i32 : i32 // CHECK-NEXT: %1 = addi %0, %c1_i32 : i32
// CHECK-NEXT: affine.store %1, %arg0[0, 0] : memref<1x2xi32> // CHECK-NEXT: affine.store %1, %arg0[0, 0] : memref<?x2xi32>
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }