From 5d40003a8ba0e255523dde17c1e2e10bfbb4405a Mon Sep 17 00:00:00 2001 From: jle-quel Date: Wed, 14 Sep 2022 09:26:47 +0200 Subject: [PATCH] Update VisitExtVectorElementExpr to handle memref-abi and llvm-abi (#269) * Fix ext_vector_type test; the test did not correctly checked the input * update ext_vector_type test to check for memref-abi equal to 0 and 1 * Update VisitExtVectorElementExpr to handle memref-abi and llvm-abi * Update tests * apply clang-format * update if else format * replace vector's type size_t by float within the test Co-authored-by: Jefferson Le Quellec --- tools/cgeist/Lib/clang-mlir.cc | 64 ++++++++++++++----- .../Test/Verification/ext_vector_type.cpp | 47 ++++++++++---- 2 files changed, 81 insertions(+), 30 deletions(-) diff --git a/tools/cgeist/Lib/clang-mlir.cc b/tools/cgeist/Lib/clang-mlir.cc index 051a8bc..bb861f9 100644 --- a/tools/cgeist/Lib/clang-mlir.cc +++ b/tools/cgeist/Lib/clang-mlir.cc @@ -463,28 +463,56 @@ mlir::Value MLIRScanner::createAllocOp(mlir::Type t, VarDecl *name, ValueCategory MLIRScanner::VisitExtVectorElementExpr(clang::ExtVectorElementExpr *expr) { auto base = Visit(expr->getBase()); + SmallVector indices; expr->getEncodedElementAccess(indices); assert(indices.size() == 1 && "The support for higher dimensions to be implemented."); - auto loc = getMLIRLocation(expr->getExprLoc()); - auto idx = castToIndex(getMLIRLocation(expr->getAccessorLoc()), - builder.create(loc, indices[0], 32)); + assert(base.isReference); base.isReference = false; - auto mt = base.val.getType().cast(); - auto shape = std::vector(mt.getShape()); - if (shape.size() == 1) { - shape[0] = -1; + + const auto et = base.val.getType(); + assert(et.isa() || et.isa()); + + ValueCategory result = nullptr; + const auto exprLoc = getMLIRLocation(expr->getExprLoc()); + const auto accLoc = getMLIRLocation(expr->getAccessorLoc()); + const mlir::Value idxs[2] = { + builder.create(exprLoc, 0, 32), + builder.create(exprLoc, indices[0], 32), + }; + + if (const auto pt = et.dyn_cast()) { + auto pt0 = + pt.getElementType().cast().getElementType(); + base.val = builder.create( + exprLoc, mlir::LLVM::LLVMPointerType::get(pt0, pt.getAddressSpace()), + base.val, idxs); + + result = ValueCategory(base.val, true); + } else if (const auto mt = et.dyn_cast()) { + auto shape = std::vector(mt.getShape()); + + if (shape.size() == 1) { + shape[0] = -1; + } else { + shape.erase(shape.begin()); + } + + auto mt0 = + mlir::MemRefType::get(shape, mt.getElementType(), + MemRefLayoutAttrInterface(), mt.getMemorySpace()); + base.val = builder.create( + exprLoc, mt0, base.val, castToIndex(accLoc, idxs[0])); + + result = CommonArrayLookup(exprLoc, base, castToIndex(accLoc, idxs[1]), + base.isReference); } else { - shape.erase(shape.begin()); + llvm_unreachable("Unexpected MLIR type received"); } - auto mt0 = - mlir::MemRefType::get(shape, mt.getElementType(), - MemRefLayoutAttrInterface(), mt.getMemorySpace()); - base.val = builder.create(loc, mt0, base.val, - getConstantIndex(0)); - return CommonArrayLookup(loc, base, idx, base.isReference); + + return result; } ValueCategory MLIRScanner::VisitConstantExpr(clang::ConstantExpr *expr) { @@ -5108,8 +5136,12 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef, } if (!memRefABI || !allowMerge || ET.isa()) - return LLVM::LLVMFixedVectorType::get(ET, size); + LLVM::LLVMFunctionType, LLVM::LLVMStructType>()) { + if (mlir::LLVM::LLVMFixedVectorType::isValidElementType(ET)) { + return mlir::LLVM::LLVMFixedVectorType::get(ET, size); + } + return mlir::LLVM::LLVMArrayType::get(ET, size); + } if (implicitRef) *implicitRef = true; return mlir::MemRefType::get({size}, ET); diff --git a/tools/cgeist/Test/Verification/ext_vector_type.cpp b/tools/cgeist/Test/Verification/ext_vector_type.cpp index 1322640..975291d 100644 --- a/tools/cgeist/Test/Verification/ext_vector_type.cpp +++ b/tools/cgeist/Test/Verification/ext_vector_type.cpp @@ -1,22 +1,41 @@ -// RUN: cgeist %s --function=* -S | FileCheck %s +// RUN: cgeist %s --function=* -memref-abi=1 -S | FileCheck %s +// RUN: cgeist %s --function=* -memref-abi=0 -S | FileCheck %s -check-prefix=CHECK2 -typedef size_t size_t_vec __attribute__((ext_vector_type(3))); +typedef float float_vec __attribute__((ext_vector_type(3))); -size_t evt(size_t_vec stv) { +float evt(float_vec stv) { return stv.x; } -extern "C" const size_t_vec stv; -size_t evt2() { +extern "C" const float_vec stv; +float evt2() { return stv.x; } -// CHECK: func.func @_Z3evtDv3_i(%arg0: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} -// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref -// CHECK-NEXT: return %0 : i32 -// CHECK-NEXT: } -// CHECK: func.func @_Z4evt2v() -> i32 attributes {llvm.linkage = #llvm.linkage} { -// CHECK-NEXT: %0 = memref.get_global @stv : memref<3xi32> -// CHECK-NEXT: %1 = affine.load %0[0] : memref<3xi32> -// CHECK-NEXT: return %1 : i32 -// CHECK-NEXT: } +// CHECK: memref.global @stv : memref<3xf32> +// CHECK: func.func @_Z3evtDv3_f(%arg0: memref) -> f32 attributes {llvm.linkage = #llvm.linkage} { +// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref +// CHECK-NEXT: return %0 : f32 +// CHECK-NEXT: } +// CHECK: func.func @_Z4evt2v() -> f32 attributes {llvm.linkage = #llvm.linkage} { +// CHECK-NEXT: %0 = memref.get_global @stv : memref<3xf32> +// CHECK-NEXT: %1 = affine.load %0[0] : memref<3xf32> +// CHECK-NEXT: return %1 : f32 +// CHECK-NEXT: } + +// CHECK2: llvm.mlir.global external @stv() {addr_space = 0 : i32} : !llvm.array<3 x f32> +// CHECK2: func.func @_Z3evtDv3_f(%arg0: !llvm.array<3 x f32>) -> f32 attributes {llvm.linkage = #llvm.linkage} { +// CHECK2-NEXT: %c1_i64 = arith.constant 1 : i64 +// CHECK2-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.array<3 x f32> : (i64) -> !llvm.ptr> +// CHECK2-NEXT: llvm.store %arg0, %0 : !llvm.ptr> +// CHECK2-NEXT: %1 = llvm.getelementptr %0[0, 0] : (!llvm.ptr>) -> !llvm.ptr +// CHECK2-NEXT: %2 = llvm.load %1 : !llvm.ptr +// CHECK2-NEXT: return %2 : f32 +// CHECK2-NEXT: } +// CHECK2: func.func @_Z4evt2v() -> f32 attributes {llvm.linkage = #llvm.linkage} { +// CHECK2-NEXT: %0 = llvm.mlir.addressof @stv : !llvm.ptr> +// CHECK2-NEXT: %1 = llvm.getelementptr %0[0, 0] : (!llvm.ptr>) -> !llvm.ptr +// CHECK2-NEXT: %2 = llvm.load %1 : !llvm.ptr +// CHECK2-NEXT: return %2 : f32 +// CHECK2-NEXT: } +