diff --git a/tools/mlir-clang/Lib/clang-mlir.cc b/tools/mlir-clang/Lib/clang-mlir.cc index 5377e9d..cbddbed 100644 --- a/tools/mlir-clang/Lib/clang-mlir.cc +++ b/tools/mlir-clang/Lib/clang-mlir.cc @@ -859,10 +859,10 @@ MLIRScanner::VisitCXXBindTemporaryExpr(clang::CXXBindTemporaryExpr *expr) { ValueCategory MLIRScanner::VisitLambdaExpr(clang::LambdaExpr *expr) { - llvm::DenseMap InnerCaptures; - FieldDecl *ThisCapture = nullptr; + // llvm::DenseMap InnerCaptures; + // FieldDecl *ThisCapture = nullptr; - expr->getLambdaClass()->getCaptureFields(InnerCaptures, ThisCapture); + // expr->getLambdaClass()->getCaptureFields(InnerCaptures, ThisCapture); bool LLVMABI = false; mlir::Type t = Glob.getMLIRType(expr->getCallOperator()->getThisType()); @@ -884,20 +884,24 @@ ValueCategory MLIRScanner::VisitLambdaExpr(clang::LambdaExpr *expr) { } auto op = createAllocOp(t, nullptr, /*memtype*/ 0, isArray, LLVMABI); - llvm::DenseMap InnerCaptureKinds; - for (auto C : expr->getLambdaClass()->captures()) { - if (C.capturesVariable()) { - InnerCaptureKinds[C.getCapturedVar()] = C.getCaptureKind(); - } - } + for (auto tup : llvm::zip(expr->getLambdaClass()->captures(), + expr->getLambdaClass()->fields())) { + auto C = std::get<0>(tup); + auto field = std::get<1>(tup); + if (C.capturesThis()) + continue; + else if (!C.capturesVariable()) + continue; + + auto CK = C.getCaptureKind(); + auto var = C.getCapturedVar(); - for (auto pair : InnerCaptures) { ValueCategory result; - if (params.find(pair.first) != params.end()) { - result = params[pair.first]; + if (params.find(var) != params.end()) { + result = params[var]; } else { - if (auto VD = dyn_cast(pair.first)) { + if (auto VD = dyn_cast(var)) { if (Captures.find(VD) != Captures.end()) { FieldDecl *field = Captures[VD]; result = CommonFieldLookup( @@ -916,20 +920,19 @@ ValueCategory MLIRScanner::VisitLambdaExpr(clang::LambdaExpr *expr) { for (auto p : params) p.first->dump(); llvm::errs() << ""; - pair.first->dump(); + var->dump(); } endp: - assert(InnerCaptureKinds.find(pair.first) != InnerCaptureKinds.end()); bool isArray = false; - Glob.getMLIRType(pair.second->getType(), &isArray); + Glob.getMLIRType(field->getType(), &isArray); - if (InnerCaptureKinds[pair.first] == LambdaCaptureKind::LCK_ByCopy) - CommonFieldLookup(expr->getCallOperator()->getThisObjectType(), - pair.second, op, /*isLValue*/ false) + if (CK == LambdaCaptureKind::LCK_ByCopy) + CommonFieldLookup(expr->getCallOperator()->getThisObjectType(), field, op, + /*isLValue*/ false) .store(builder, result, isArray); else { - assert(InnerCaptureKinds[pair.first] == LambdaCaptureKind::LCK_ByRef); + assert(CK == LambdaCaptureKind::LCK_ByRef); assert(result.isReference); auto val = result.val; @@ -944,8 +947,8 @@ ValueCategory MLIRScanner::VisitLambdaExpr(clang::LambdaExpr *expr) { val); } - CommonFieldLookup(expr->getCallOperator()->getThisObjectType(), - pair.second, op, /*isLValue*/ false) + CommonFieldLookup(expr->getCallOperator()->getThisObjectType(), field, op, + /*isLValue*/ false) .store(builder, val); } } @@ -1260,11 +1263,10 @@ ValueCategory MLIRScanner::CommonArrayLookup(ValueCategory array, auto mt = dref.val.getType().cast(); auto shape = std::vector(mt.getShape()); - if (shape.size() > 1) { - // if (shape.size() > 2 || (shape.size() > 1 && !isImplicitRefResult)) { - shape.erase(shape.begin()); - } else { + if (shape.size() == 1 || shape.size() == 2 && isImplicitRefResult) { shape[0] = -1; + } else { + shape.erase(shape.begin()); } auto mt0 = mlir::MemRefType::get(shape, mt.getElementType(), @@ -1297,7 +1299,8 @@ MLIRScanner::VisitArraySubscriptExpr(clang::ArraySubscriptExpr *expr) { getConstantIndex(0)); } bool isArray = false; - Glob.getMLIRType(expr->getType(), &isArray); + if (!Glob.CGM.getContext().getAsArrayType(expr->getType())) + Glob.getMLIRType(expr->getType(), &isArray); return CommonArrayLookup(moo, idx, isArray); } @@ -3881,9 +3884,8 @@ ValueCategory MLIRScanner::CommonFieldLookup(clang::QualType CT, if (rd->isUnion() || (CXRD && (!CXRD->hasDefinition() || CXRD->isPolymorphic() || CXRD->getDefinition()->getNumBases() > 0)) || - recursive || - (!ST->isLiteral() && (ST->getName().contains("SmallVector") || - ST->getName() == "struct._IO_FILE" || + recursive || + (!ST->isLiteral() && (ST->getName() == "struct._IO_FILE" || ST->getName() == "class.std::basic_ifstream" || ST->getName() == "class.std::basic_istream" || ST->getName() == "class.std::basic_ostream" || @@ -4149,8 +4151,7 @@ ValueCategory MLIRScanner::VisitCastExpr(CastExpr *E) { } assert(se.val); if (auto opt = se.val.getType().dyn_cast()) { - auto pt = Glob.typeTranslator.translateType( - anonymize(getLLVMType(E->getType()))); + auto pt = getMLIRType(E->getType()).cast(); if (se.isReference) pt = mlir::LLVM::LLVMPointerType::get(pt, opt.getAddressSpace()); auto nval = builder.create(loc, pt, se.val); @@ -4258,10 +4259,13 @@ ValueCategory MLIRScanner::VisitCastExpr(CastExpr *E) { } auto scalar = se.getValue(builder); if (auto spt = scalar.getType().dyn_cast()) { - LLVM::LLVMPointerType pt = - Glob.typeTranslator - .translateType(anonymize(getLLVMType(E->getType()))) - .cast(); + auto nt = getMLIRType(E->getType()); + LLVM::LLVMPointerType pt = nt.dyn_cast(); + if (!pt) { + return ValueCategory( + builder.create(loc, nt, scalar), + false); + } pt = LLVM::LLVMPointerType::get(pt.getElementType(), spt.getAddressSpace()); auto nval = builder.create(loc, pt, scalar); @@ -5428,9 +5432,8 @@ mlir::Type MLIRASTConsumer::getMLIRType(clang::QualType qt, bool *implicitRef, if (RT->getDecl()->isUnion() || (CXRD && (!CXRD->hasDefinition() || CXRD->isPolymorphic() || CXRD->getDefinition()->getNumBases() > 0)) || - recursive || - (!ST->isLiteral() && (ST->getName().contains("SmallVector") || - ST->getName() == "struct._IO_FILE" || + recursive || + (!ST->isLiteral() && (ST->getName() == "struct._IO_FILE" || ST->getName() == "class.std::basic_ifstream" || ST->getName() == "class.std::basic_istream" || ST->getName() == "class.std::basic_ostream" || @@ -5650,8 +5653,7 @@ mlir::Type MLIRASTConsumer::getMLIRType(llvm::Type *t) { if (auto ST = dyn_cast(pt->getElementType())) { if (ST->getNumElements() == 0 || (!ST->isLiteral() && - (ST->getName().contains("SmallVector") || - ST->getName() == "struct._IO_FILE" || + (ST->getName() == "struct._IO_FILE" || ST->getName() == "class.std::basic_ifstream" || ST->getName() == "class.std::basic_istream" || ST->getName() == "class.std::basic_ostream" || @@ -5726,8 +5728,7 @@ mlir::Type MLIRASTConsumer::getMLIRType(llvm::Type *t) { if (!recursive && ST->getNumElements() == 1) return getMLIRType(ST->getTypeAtIndex(0U)); if (ST->getNumElements() == 0 || recursive || - (!ST->isLiteral() && (ST->getName().contains("SmallVector") || - ST->getName() == "struct._IO_FILE" || + (!ST->isLiteral() && (ST->getName() == "struct._IO_FILE" || ST->getName() == "class.std::basic_ifstream" || ST->getName() == "class.std::basic_istream" || ST->getName() == "class.std::basic_ostream" || diff --git a/tools/mlir-clang/Test/Verification/caff.cpp b/tools/mlir-clang/Test/Verification/caff.cpp new file mode 100644 index 0000000..77577f5 --- /dev/null +++ b/tools/mlir-clang/Test/Verification/caff.cpp @@ -0,0 +1,51 @@ +// RUN: mlir-clang %s --function=* -S | FileCheck %s + +struct AOperandInfo { + void* data; + + bool is_output; + + bool is_read_write; +}; + + +/// This is all the non-templated stuff common to all SmallVectors. + +/// This is the part of SmallVectorTemplateBase which does not depend on whether +/// the type T is a POD. The extra dummy template argument is used by ArrayRef +/// to avoid unnecessarily requiring T to be complete. +template +class ASmallVectorTemplateCommon { + public: + void *BeginX, *EndX; + + // forward iterator creation methods. + const T* begin() const { + return (const T*)this->BeginX; + } +}; + +unsigned long long int div_kernel_cuda(ASmallVectorTemplateCommon &operands) { + return (const AOperandInfo*)operands.EndX - operands.begin(); +} + + +// CHECK: func @_Z15div_kernel_cudaR26ASmallVectorTemplateCommonI12AOperandInfoE(%arg0: !llvm.ptr, ptr)>>) -> i64 attributes {llvm.linkage = #llvm.linkage} { +// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, %c1_i32] : (!llvm.ptr, ptr)>>, i32, i32) -> !llvm.ptr> +// CHECK-NEXT: %1 = llvm.load %0 : !llvm.ptr> +// CHECK-NEXT: %2 = llvm.bitcast %1 : !llvm.ptr to !llvm.ptr, i8, i8)>> +// CHECK-NEXT: %3 = call @_ZNK26ASmallVectorTemplateCommonI12AOperandInfoE5beginEv(%arg0) : (!llvm.ptr, ptr)>>) -> !llvm.ptr, i8, i8)>> +// CHECK-NEXT: %4 = llvm.ptrtoint %3 : !llvm.ptr, i8, i8)>> to i64 +// CHECK-NEXT: %5 = llvm.ptrtoint %2 : !llvm.ptr, i8, i8)>> to i64 +// CHECK-NEXT: %6 = arith.subi %5, %4 : i64 +// CHECK-NEXT: return %6 : i64 +// CHECK-NEXT: } +// CHECK: func @_ZNK26ASmallVectorTemplateCommonI12AOperandInfoE5beginEv(%arg0: !llvm.ptr, ptr)>>) -> !llvm.ptr, i8, i8)>> attributes {llvm.linkage = #llvm.linkage} { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr, ptr)>>, i32, i32) -> !llvm.ptr> +// CHECK-NEXT: %1 = llvm.load %0 : !llvm.ptr> +// CHECK-NEXT: %2 = llvm.bitcast %1 : !llvm.ptr to !llvm.ptr, i8, i8)>> +// CHECK-NEXT: return %2 : !llvm.ptr, i8, i8)>> +// CHECK-NEXT: } diff --git a/tools/mlir-clang/Test/Verification/capture.cpp b/tools/mlir-clang/Test/Verification/capture.cpp index 318c309..0b3a280 100644 --- a/tools/mlir-clang/Test/Verification/capture.cpp +++ b/tools/mlir-clang/Test/Verification/capture.cpp @@ -19,11 +19,11 @@ double kernel_deriche(int x, float y) { // CHECK-NEXT: %1 = llvm.alloca %c1_i64 x !llvm.struct<(memref, i32)> : (i64) -> !llvm.ptr, i32)>> // CHECK-NEXT: %2 = memref.alloca() : memref<1xf32> // CHECK-NEXT: affine.store %arg1, %2[0] : memref<1xf32> -// CHECK-NEXT: %3 = llvm.getelementptr %1[%c0_i32, %c1_i32] : (!llvm.ptr, i32)>>, i32, i32) -> !llvm.ptr -// CHECK-NEXT: llvm.store %arg0, %3 : !llvm.ptr -// CHECK-NEXT: %4 = memref.cast %2 : memref<1xf32> to memref -// CHECK-NEXT: %5 = llvm.getelementptr %1[%c0_i32, %c0_i32] : (!llvm.ptr, i32)>>, i32, i32) -> !llvm.ptr> -// CHECK-NEXT: llvm.store %4, %5 : !llvm.ptr> +// CHECK-NEXT: %3 = memref.cast %2 : memref<1xf32> to memref +// CHECK-NEXT: %4 = llvm.getelementptr %1[%c0_i32, %c0_i32] : (!llvm.ptr, i32)>>, i32, i32) -> !llvm.ptr> +// CHECK-NEXT: llvm.store %3, %4 : !llvm.ptr> +// CHECK-NEXT: %5 = llvm.getelementptr %1[%c0_i32, %c1_i32] : (!llvm.ptr, i32)>>, i32, i32) -> !llvm.ptr +// CHECK-NEXT: llvm.store %arg0, %5 : !llvm.ptr // CHECK-NEXT: %6 = llvm.load %1 : !llvm.ptr, i32)>> // CHECK-NEXT: llvm.store %6, %0 : !llvm.ptr, i32)>> // CHECK-NEXT: call @_ZZ14kernel_dericheENK3$_0clEv(%0) : (!llvm.ptr, i32)>>) -> () diff --git a/tools/mlir-clang/Test/Verification/ident.cpp b/tools/mlir-clang/Test/Verification/ident.cpp index add2920..7cce76a 100644 --- a/tools/mlir-clang/Test/Verification/ident.cpp +++ b/tools/mlir-clang/Test/Verification/ident.cpp @@ -41,55 +41,54 @@ void lt_kernel_cuda(MTensorIterator& iter) { } } -// CHECK: func @lt_kernel_cuda(%arg0: !llvm.ptr>)>)>>) attributes {llvm.linkage = #llvm.linkage} { -// CHECK-DAG: %c0_i32 = arith.constant 0 : i32 -// CHECK-DAG: %c1_i64 = arith.constant 1 : i64 -// CHECK-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.struct<(ptr>)>)>>)> : (i64) -> !llvm.ptr>)>)>>)>> -// CHECK-NEXT: %1 = llvm.alloca %c1_i64 x !llvm.struct<(ptr>)>)>>)> : (i64) -> !llvm.ptr>)>)>>)>> -// CHECK-NEXT: %2 = call @_ZNK15MTensorIterator11input_dtypeEv(%arg0) : (!llvm.ptr>)>)>>) -> i8 +// CHECK: func @lt_kernel_cuda(%arg0: !llvm.ptr)>)>>) attributes {llvm.linkage = #llvm.linkage} { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c1_i64 = arith.constant 1 : i64 +// CHECK-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.struct<(!llvm.ptr)>)>>)> : (i64) -> !llvm.ptr)>)>>)>> +// CHECK-NEXT: %1 = llvm.alloca %c1_i64 x !llvm.struct<(!llvm.ptr)>)>>)> : (i64) -> !llvm.ptr)>)>>)>> +// CHECK-NEXT: %2 = call @_ZNK15MTensorIterator11input_dtypeEv(%arg0) : (!llvm.ptr)>)>>) -> i8 // CHECK-NEXT: %3 = arith.trunci %2 : i8 to i1 // CHECK-NEXT: scf.if %3 { -// CHECK-NEXT: %4 = llvm.getelementptr %1[%c0_i32, %c0_i32] : (!llvm.ptr>)>)>>)>>, i32, i32) -> !llvm.ptr>)>)>>> -// CHECK-NEXT: llvm.store %arg0, %4 : !llvm.ptr>)>)>>> -// CHECK-NEXT: %5 = llvm.load %1 : !llvm.ptr>)>)>>)>> -// CHECK-NEXT: llvm.store %5, %0 : !llvm.ptr>)>)>>)>> -// CHECK-NEXT: call @_ZZ14lt_kernel_cudaENK3$_0clEv(%0) : (!llvm.ptr>)>)>>)>>) -> () +// CHECK-NEXT: %4 = llvm.getelementptr %1[%c0_i32, %c0_i32] : (!llvm.ptr)>)>>)>>, i32, i32) -> !llvm.ptr)>)>>> +// CHECK-NEXT: llvm.store %arg0, %4 : !llvm.ptr)>)>>> +// CHECK-NEXT: %5 = llvm.load %1 : !llvm.ptr)>)>>)>> +// CHECK-NEXT: llvm.store %5, %0 : !llvm.ptr)>)>>)>> +// CHECK-NEXT: call @_ZZ14lt_kernel_cudaENK3$_0clEv(%0) : (!llvm.ptr)>)>>)>>) -> () // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } -// CHECK-NEXT: func @_ZNK15MTensorIterator11input_dtypeEv(%arg0: !llvm.ptr>)>)>>) -> i8 attributes {llvm.linkage = #llvm.linkage} { +// CHECK: func @_ZNK15MTensorIterator11input_dtypeEv(%arg0: !llvm.ptr)>)>>) -> i8 attributes {llvm.linkage = #llvm.linkage} { // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr>)>)>>, i32, i32) -> !llvm.ptr>)>> -// CHECK-NEXT: %1 = call @_ZNK12MSmallVectorI12MOperandInfoEixEi(%0, %c0_i32) : (!llvm.ptr>)>>, i32) -> memref -// CHECK-NEXT: %2 = affine.load %1[0, 1] : memref -// CHECK-NEXT: return %2 : i8 +// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr)>)>>, i32, i32) -> !llvm.ptr)>> +// CHECK-NEXT: %1 = "polygeist.pointer2memref"(%0) : (!llvm.ptr)>>) -> memref> +// CHECK-NEXT: %2 = call @_ZNK12MSmallVectorI12MOperandInfoEixEi(%1, %c0_i32) : (memref>, i32) -> memref +// CHECK-NEXT: %3 = affine.load %2[0, 1] : memref +// CHECK-NEXT: return %3 : i8 // CHECK-NEXT: } -// CHECK-NEXT: func private @_ZZ14lt_kernel_cudaENK3$_0clEv(%arg0: !llvm.ptr>)>)>>)>>) attributes {llvm.linkage = #llvm.linkage} { -// CHECK-DAG: %c0_i32 = arith.constant 0 : i32 -// CHECK-DAG: %c1_i64 = arith.constant 1 : i64 +// CHECK: func private @_ZZ14lt_kernel_cudaENK3$_0clEv(%arg0: !llvm.ptr)>)>>)>>) attributes {llvm.linkage = #llvm.linkage} { +// CHECK-NEXT: %c1_i64 = arith.constant 1 : i64 +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 // CHECK-NEXT: %0 = llvm.alloca %c1_i64 x !llvm.struct<(i8)> : (i64) -> !llvm.ptr> // CHECK-NEXT: %1 = llvm.alloca %c1_i64 x !llvm.struct<(i8)> : (i64) -> !llvm.ptr> -// CHECK-NEXT: %2 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr>)>)>>)>>, i32, i32) -> !llvm.ptr>)>)>>> -// CHECK-NEXT: %3 = llvm.load %2 : !llvm.ptr>)>)>>> +// CHECK-NEXT: %2 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr)>)>>)>>, i32, i32) -> !llvm.ptr)>)>>> +// CHECK-NEXT: %3 = llvm.load %2 : !llvm.ptr)>)>>> // CHECK-NEXT: %4 = llvm.load %1 : !llvm.ptr> // CHECK-NEXT: llvm.store %4, %0 : !llvm.ptr> -// CHECK-NEXT: %5 = call @_ZNK15MTensorIterator6deviceEv(%3) : (!llvm.ptr>)>)>>) -> i8 +// CHECK-NEXT: %5 = call @_ZNK15MTensorIterator6deviceEv(%3) : (!llvm.ptr)>)>>) -> i8 // CHECK-NEXT: return // CHECK-NEXT: } -// CHECK-NEXT: func @_ZNK12MSmallVectorI12MOperandInfoEixEi(%arg0: !llvm.ptr>)>>, %arg1: i32) -> memref attributes {llvm.linkage = #llvm.linkage} { -// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr>)>>, i32, i32) -> !llvm.ptr>> -// CHECK-NEXT: %1 = llvm.load %0 : !llvm.ptr>> -// CHECK-NEXT: %2 = arith.index_cast %arg1 : i32 to index -// CHECK-NEXT: %3 = arith.index_cast %2 : index to i64 -// CHECK-NEXT: %4 = llvm.getelementptr %1[%3] : (!llvm.ptr>, i64) -> !llvm.ptr> -// CHECK-NEXT: %5 = "polygeist.pointer2memref"(%4) : (!llvm.ptr>) -> memref -// CHECK-NEXT: return %5 : memref +// CHECK: func @_ZNK12MSmallVectorI12MOperandInfoEixEi(%arg0: memref>, %arg1: i32) -> memref attributes {llvm.linkage = #llvm.linkage} { +// CHECK-NEXT: %0 = affine.load %arg0[0, 0] : memref> +// CHECK-NEXT: %1 = arith.index_cast %arg1 : i32 to index +// CHECK-NEXT: %2 = "polygeist.subindex"(%0, %1) : (memref, index) -> memref +// CHECK-NEXT: return %2 : memref // CHECK-NEXT: } -// CHECK-NEXT: func @_ZNK15MTensorIterator6deviceEv(%arg0: !llvm.ptr>)>)>>) -> i8 attributes {llvm.linkage = #llvm.linkage} { +// CHECK-NEXT: func @_ZNK15MTensorIterator6deviceEv(%arg0: !llvm.ptr)>)>>) -> i8 attributes {llvm.linkage = #llvm.linkage} { // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr>)>)>>, i32, i32) -> !llvm.ptr>)>> -// CHECK-NEXT: %1 = call @_ZNK12MSmallVectorI12MOperandInfoEixEi(%0, %c0_i32) : (!llvm.ptr>)>>, i32) -> memref -// CHECK-NEXT: %2 = affine.load %1[0, 0] : memref -// CHECK-NEXT: return %2 : i8 +// CHECK-NEXT: %0 = llvm.getelementptr %arg0[%c0_i32, %c0_i32] : (!llvm.ptr)>)>>, i32, i32) -> !llvm.ptr)>> +// CHECK-NEXT: %1 = "polygeist.pointer2memref"(%0) : (!llvm.ptr)>>) -> memref> +// CHECK-NEXT: %2 = call @_ZNK12MSmallVectorI12MOperandInfoEixEi(%1, %c0_i32) : (memref>, i32) -> memref +// CHECK-NEXT: %3 = affine.load %2[0, 0] : memref +// CHECK-NEXT: return %3 : i8 // CHECK-NEXT: } + diff --git a/tools/mlir-clang/Test/Verification/ident2.cpp b/tools/mlir-clang/Test/Verification/ident2.cpp index b5a6617..a0ef444 100644 --- a/tools/mlir-clang/Test/Verification/ident2.cpp +++ b/tools/mlir-clang/Test/Verification/ident2.cpp @@ -1,7 +1,5 @@ // RUN: mlir-clang %s --function=* -S | FileCheck %s -// XFAIL: * - struct MOperandInfo { char device; char dtype; @@ -13,3 +11,8 @@ struct MOperandInfo& inner() { return begin()[0]; } +// CHECK: func @_Z5innerv() -> memref attributes {llvm.linkage = #llvm.linkage} { +// CHECK-NEXT: %0 = call @_Z5beginv() : () -> memref +// CHECK-NEXT: return %0 : memref +// CHECK-NEXT: } +// CHECK-NEXT: func private @_Z5beginv() -> memref attributes {llvm.linkage = #llvm.linkage} diff --git a/tools/mlir-clang/mlir-clang.cc b/tools/mlir-clang/mlir-clang.cc index a821564..f255e52 100644 --- a/tools/mlir-clang/mlir-clang.cc +++ b/tools/mlir-clang/mlir-clang.cc @@ -390,6 +390,10 @@ int main(int argc, char **argv) { MemRefType::attachInterface>(context); LLVM::LLVMStructType::attachInterface>( context); + LLVM::LLVMPointerType::attachInterface< + PtrElementModel>(context); + LLVM::LLVMArrayType::attachInterface>( + context); if (showDialects) { outs() << "Registered Dialects:\n";