Update VisitExtVectorElementExpr to handle memref-abi and llvm-abi (#269)

* Fix ext_vector_type test; the test did not correctly checked the input

* update ext_vector_type test to check for memref-abi equal to 0 and 1

* Update VisitExtVectorElementExpr to handle memref-abi and llvm-abi

* Update tests

* apply clang-format

* update if else format

* replace vector's type size_t by float within the test

Co-authored-by: Jefferson Le Quellec <jefferson.lequellec@codeplay.com>
This commit is contained in:
jle-quel 2022-09-14 09:26:47 +02:00 committed by GitHub
parent f6a9282c92
commit 5d40003a8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 30 deletions

View File

@ -463,28 +463,56 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name,
ValueCategory
MLIRScanner::VisitExtVectorElementExpr(clang::ExtVectorElementExpr *expr) {
auto base = Visit(expr->getBase());
SmallVector<uint32_t, 4> indices;
expr->getEncodedElementAccess(indices);
assert(indices.size() == 1 &&
"The support for higher dimensions to be implemented.");
auto loc = getMLIRLocation(expr->getExprLoc());
auto idx = castToIndex(getMLIRLocation(expr->getAccessorLoc()),
builder.create<ConstantIntOp>(loc, indices[0], 32));
assert(base.isReference);
base.isReference = false;
auto mt = base.val.getType().cast<MemRefType>();
auto shape = std::vector<int64_t>(mt.getShape());
if (shape.size() == 1) {
shape[0] = -1;
const auto et = base.val.getType();
assert(et.isa<LLVM::LLVMPointerType>() || et.isa<MemRefType>());
ValueCategory result = nullptr;
const auto exprLoc = getMLIRLocation(expr->getExprLoc());
const auto accLoc = getMLIRLocation(expr->getAccessorLoc());
const mlir::Value idxs[2] = {
builder.create<ConstantIntOp>(exprLoc, 0, 32),
builder.create<ConstantIntOp>(exprLoc, indices[0], 32),
};
if (const auto pt = et.dyn_cast<LLVM::LLVMPointerType>()) {
auto pt0 =
pt.getElementType().cast<mlir::LLVM::LLVMArrayType>().getElementType();
base.val = builder.create<mlir::LLVM::GEPOp>(
exprLoc, mlir::LLVM::LLVMPointerType::get(pt0, pt.getAddressSpace()),
base.val, idxs);
result = ValueCategory(base.val, true);
} else if (const auto mt = et.dyn_cast<MemRefType>()) {
auto shape = std::vector<int64_t>(mt.getShape());
if (shape.size() == 1) {
shape[0] = -1;
} else {
shape.erase(shape.begin());
}
auto mt0 =
mlir::MemRefType::get(shape, mt.getElementType(),
MemRefLayoutAttrInterface(), mt.getMemorySpace());
base.val = builder.create<polygeist::SubIndexOp>(
exprLoc, mt0, base.val, castToIndex(accLoc, idxs[0]));
result = CommonArrayLookup(exprLoc, base, castToIndex(accLoc, idxs[1]),
base.isReference);
} else {
shape.erase(shape.begin());
llvm_unreachable("Unexpected MLIR type received");
}
auto mt0 =
mlir::MemRefType::get(shape, mt.getElementType(),
MemRefLayoutAttrInterface(), mt.getMemorySpace());
base.val = builder.create<polygeist::SubIndexOp>(loc, mt0, base.val,
getConstantIndex(0));
return CommonArrayLookup(loc, base, idx, base.isReference);
return result;
}
ValueCategory MLIRScanner::VisitConstantExpr(clang::ConstantExpr *expr) {
@ -5108,8 +5136,12 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
}
if (!memRefABI || !allowMerge ||
ET.isa<LLVM::LLVMPointerType, LLVM::LLVMArrayType,
LLVM::LLVMFunctionType, LLVM::LLVMStructType>())
return LLVM::LLVMFixedVectorType::get(ET, size);
LLVM::LLVMFunctionType, LLVM::LLVMStructType>()) {
if (mlir::LLVM::LLVMFixedVectorType::isValidElementType(ET)) {
return mlir::LLVM::LLVMFixedVectorType::get(ET, size);
}
return mlir::LLVM::LLVMArrayType::get(ET, size);
}
if (implicitRef)
*implicitRef = true;
return mlir::MemRefType::get({size}, ET);

View File

@ -1,22 +1,41 @@
// RUN: cgeist %s --function=* -S | FileCheck %s
// RUN: cgeist %s --function=* -memref-abi=1 -S | FileCheck %s
// RUN: cgeist %s --function=* -memref-abi=0 -S | FileCheck %s -check-prefix=CHECK2
typedef size_t size_t_vec __attribute__((ext_vector_type(3)));
typedef float float_vec __attribute__((ext_vector_type(3)));
size_t evt(size_t_vec stv) {
float evt(float_vec stv) {
return stv.x;
}
extern "C" const size_t_vec stv;
size_t evt2() {
extern "C" const float_vec stv;
float evt2() {
return stv.x;
}
// CHECK: func.func @_Z3evtDv3_i(%arg0: memref<?x3xi32>) -> i32 attributes {llvm.linkage = #llvm.linkage<external>}
// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<?x3xi32>
// CHECK-NEXT: return %0 : i32
// CHECK-NEXT: }
// CHECK: func.func @_Z4evt2v() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
// CHECK-NEXT: %0 = memref.get_global @stv : memref<3xi32>
// CHECK-NEXT: %1 = affine.load %0[0] : memref<3xi32>
// CHECK-NEXT: return %1 : i32
// CHECK-NEXT: }
// CHECK: memref.global @stv : memref<3xf32>
// CHECK: func.func @_Z3evtDv3_f(%arg0: memref<?x3xf32>) -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<?x3xf32>
// CHECK-NEXT: return %0 : f32
// CHECK-NEXT: }
// CHECK: func.func @_Z4evt2v() -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
// CHECK-NEXT: %0 = memref.get_global @stv : memref<3xf32>
// CHECK-NEXT: %1 = affine.load %0[0] : memref<3xf32>
// CHECK-NEXT: return %1 : f32
// CHECK-NEXT: }
// CHECK2: llvm.mlir.global external @stv() {addr_space = 0 : i32} : !llvm.array<3 x f32>
// CHECK2: func.func @_Z3evtDv3_f(%arg0: !llvm.array<3 x f32>) -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
// CHECK2-NEXT: %c1_i64 = arith.constant 1 : i64
// CHECK2-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.array<3 x f32> : (i64) -> !llvm.ptr<array<3 x f32>>
// CHECK2-NEXT: llvm.store %arg0, %0 : !llvm.ptr<array<3 x f32>>
// CHECK2-NEXT: %1 = llvm.getelementptr %0[0, 0] : (!llvm.ptr<array<3 x f32>>) -> !llvm.ptr<f32>
// CHECK2-NEXT: %2 = llvm.load %1 : !llvm.ptr<f32>
// CHECK2-NEXT: return %2 : f32
// CHECK2-NEXT: }
// CHECK2: func.func @_Z4evt2v() -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
// CHECK2-NEXT: %0 = llvm.mlir.addressof @stv : !llvm.ptr<array<3 x f32>>
// CHECK2-NEXT: %1 = llvm.getelementptr %0[0, 0] : (!llvm.ptr<array<3 x f32>>) -> !llvm.ptr<f32>
// CHECK2-NEXT: %2 = llvm.load %1 : !llvm.ptr<f32>
// CHECK2-NEXT: return %2 : f32
// CHECK2-NEXT: }