Update VisitExtVectorElementExpr to handle memref-abi and llvm-abi (#269)
* Fix ext_vector_type test; the test did not correctly checked the input * update ext_vector_type test to check for memref-abi equal to 0 and 1 * Update VisitExtVectorElementExpr to handle memref-abi and llvm-abi * Update tests * apply clang-format * update if else format * replace vector's type size_t by float within the test Co-authored-by: Jefferson Le Quellec <jefferson.lequellec@codeplay.com>
This commit is contained in:
parent
f6a9282c92
commit
5d40003a8b
|
@ -463,28 +463,56 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name,
|
|||
ValueCategory
|
||||
MLIRScanner::VisitExtVectorElementExpr(clang::ExtVectorElementExpr *expr) {
|
||||
auto base = Visit(expr->getBase());
|
||||
|
||||
SmallVector<uint32_t, 4> indices;
|
||||
expr->getEncodedElementAccess(indices);
|
||||
assert(indices.size() == 1 &&
|
||||
"The support for higher dimensions to be implemented.");
|
||||
auto loc = getMLIRLocation(expr->getExprLoc());
|
||||
auto idx = castToIndex(getMLIRLocation(expr->getAccessorLoc()),
|
||||
builder.create<ConstantIntOp>(loc, indices[0], 32));
|
||||
|
||||
assert(base.isReference);
|
||||
base.isReference = false;
|
||||
auto mt = base.val.getType().cast<MemRefType>();
|
||||
|
||||
const auto et = base.val.getType();
|
||||
assert(et.isa<LLVM::LLVMPointerType>() || et.isa<MemRefType>());
|
||||
|
||||
ValueCategory result = nullptr;
|
||||
const auto exprLoc = getMLIRLocation(expr->getExprLoc());
|
||||
const auto accLoc = getMLIRLocation(expr->getAccessorLoc());
|
||||
const mlir::Value idxs[2] = {
|
||||
builder.create<ConstantIntOp>(exprLoc, 0, 32),
|
||||
builder.create<ConstantIntOp>(exprLoc, indices[0], 32),
|
||||
};
|
||||
|
||||
if (const auto pt = et.dyn_cast<LLVM::LLVMPointerType>()) {
|
||||
auto pt0 =
|
||||
pt.getElementType().cast<mlir::LLVM::LLVMArrayType>().getElementType();
|
||||
base.val = builder.create<mlir::LLVM::GEPOp>(
|
||||
exprLoc, mlir::LLVM::LLVMPointerType::get(pt0, pt.getAddressSpace()),
|
||||
base.val, idxs);
|
||||
|
||||
result = ValueCategory(base.val, true);
|
||||
} else if (const auto mt = et.dyn_cast<MemRefType>()) {
|
||||
auto shape = std::vector<int64_t>(mt.getShape());
|
||||
|
||||
if (shape.size() == 1) {
|
||||
shape[0] = -1;
|
||||
} else {
|
||||
shape.erase(shape.begin());
|
||||
}
|
||||
|
||||
auto mt0 =
|
||||
mlir::MemRefType::get(shape, mt.getElementType(),
|
||||
MemRefLayoutAttrInterface(), mt.getMemorySpace());
|
||||
base.val = builder.create<polygeist::SubIndexOp>(loc, mt0, base.val,
|
||||
getConstantIndex(0));
|
||||
return CommonArrayLookup(loc, base, idx, base.isReference);
|
||||
base.val = builder.create<polygeist::SubIndexOp>(
|
||||
exprLoc, mt0, base.val, castToIndex(accLoc, idxs[0]));
|
||||
|
||||
result = CommonArrayLookup(exprLoc, base, castToIndex(accLoc, idxs[1]),
|
||||
base.isReference);
|
||||
} else {
|
||||
llvm_unreachable("Unexpected MLIR type received");
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ValueCategory MLIRScanner::VisitConstantExpr(clang::ConstantExpr *expr) {
|
||||
|
@ -5108,8 +5136,12 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef,
|
|||
}
|
||||
if (!memRefABI || !allowMerge ||
|
||||
ET.isa<LLVM::LLVMPointerType, LLVM::LLVMArrayType,
|
||||
LLVM::LLVMFunctionType, LLVM::LLVMStructType>())
|
||||
return LLVM::LLVMFixedVectorType::get(ET, size);
|
||||
LLVM::LLVMFunctionType, LLVM::LLVMStructType>()) {
|
||||
if (mlir::LLVM::LLVMFixedVectorType::isValidElementType(ET)) {
|
||||
return mlir::LLVM::LLVMFixedVectorType::get(ET, size);
|
||||
}
|
||||
return mlir::LLVM::LLVMArrayType::get(ET, size);
|
||||
}
|
||||
if (implicitRef)
|
||||
*implicitRef = true;
|
||||
return mlir::MemRefType::get({size}, ET);
|
||||
|
|
|
@ -1,22 +1,41 @@
|
|||
// RUN: cgeist %s --function=* -S | FileCheck %s
|
||||
// RUN: cgeist %s --function=* -memref-abi=1 -S | FileCheck %s
|
||||
// RUN: cgeist %s --function=* -memref-abi=0 -S | FileCheck %s -check-prefix=CHECK2
|
||||
|
||||
typedef size_t size_t_vec __attribute__((ext_vector_type(3)));
|
||||
typedef float float_vec __attribute__((ext_vector_type(3)));
|
||||
|
||||
size_t evt(size_t_vec stv) {
|
||||
float evt(float_vec stv) {
|
||||
return stv.x;
|
||||
}
|
||||
|
||||
extern "C" const size_t_vec stv;
|
||||
size_t evt2() {
|
||||
extern "C" const float_vec stv;
|
||||
float evt2() {
|
||||
return stv.x;
|
||||
}
|
||||
|
||||
// CHECK: func.func @_Z3evtDv3_i(%arg0: memref<?x3xi32>) -> i32 attributes {llvm.linkage = #llvm.linkage<external>}
|
||||
// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<?x3xi32>
|
||||
// CHECK-NEXT: return %0 : i32
|
||||
// CHECK: memref.global @stv : memref<3xf32>
|
||||
// CHECK: func.func @_Z3evtDv3_f(%arg0: memref<?x3xf32>) -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
|
||||
// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref<?x3xf32>
|
||||
// CHECK-NEXT: return %0 : f32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK: func.func @_Z4evt2v() -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
|
||||
// CHECK-NEXT: %0 = memref.get_global @stv : memref<3xi32>
|
||||
// CHECK-NEXT: %1 = affine.load %0[0] : memref<3xi32>
|
||||
// CHECK-NEXT: return %1 : i32
|
||||
// CHECK: func.func @_Z4evt2v() -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
|
||||
// CHECK-NEXT: %0 = memref.get_global @stv : memref<3xf32>
|
||||
// CHECK-NEXT: %1 = affine.load %0[0] : memref<3xf32>
|
||||
// CHECK-NEXT: return %1 : f32
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK2: llvm.mlir.global external @stv() {addr_space = 0 : i32} : !llvm.array<3 x f32>
|
||||
// CHECK2: func.func @_Z3evtDv3_f(%arg0: !llvm.array<3 x f32>) -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
|
||||
// CHECK2-NEXT: %c1_i64 = arith.constant 1 : i64
|
||||
// CHECK2-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.array<3 x f32> : (i64) -> !llvm.ptr<array<3 x f32>>
|
||||
// CHECK2-NEXT: llvm.store %arg0, %0 : !llvm.ptr<array<3 x f32>>
|
||||
// CHECK2-NEXT: %1 = llvm.getelementptr %0[0, 0] : (!llvm.ptr<array<3 x f32>>) -> !llvm.ptr<f32>
|
||||
// CHECK2-NEXT: %2 = llvm.load %1 : !llvm.ptr<f32>
|
||||
// CHECK2-NEXT: return %2 : f32
|
||||
// CHECK2-NEXT: }
|
||||
// CHECK2: func.func @_Z4evt2v() -> f32 attributes {llvm.linkage = #llvm.linkage<external>} {
|
||||
// CHECK2-NEXT: %0 = llvm.mlir.addressof @stv : !llvm.ptr<array<3 x f32>>
|
||||
// CHECK2-NEXT: %1 = llvm.getelementptr %0[0, 0] : (!llvm.ptr<array<3 x f32>>) -> !llvm.ptr<f32>
|
||||
// CHECK2-NEXT: %2 = llvm.load %1 : !llvm.ptr<f32>
|
||||
// CHECK2-NEXT: return %2 : f32
|
||||
// CHECK2-NEXT: }
|
||||
|
||||
|
|
Loading…
Reference in New Issue