diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index d7b0f819e9da..fb9e390a0f7a 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -19,6 +19,7 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/ArrayRef.h" @@ -45,6 +46,25 @@ protected: } }; +/// FIR conversion pattern template +template +class FIROpAndTypeConversion : public FIROpConversion { +public: + using FIROpConversion::FIROpConversion; + using OpAdaptor = typename FromOp::Adaptor; + + mlir::LogicalResult + matchAndRewrite(FromOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const final { + mlir::Type ty = this->convertType(op.getType()); + return doRewrite(op, ty, adaptor, rewriter); + } + + virtual mlir::LogicalResult + doRewrite(FromOp addr, mlir::Type ty, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const = 0; +}; + // Lower `fir.address_of` operation to `llvm.address_of` operation. struct AddrOfOpConversion : public FIROpConversion { using FIROpConversion::FIROpConversion; @@ -204,6 +224,82 @@ struct ZeroOpConversion : public FIROpConversion { } }; +/// InsertOnRange inserts a value into a sequence over a range of offsets. +struct InsertOnRangeOpConversion + : public FIROpAndTypeConversion { + using FIROpAndTypeConversion::FIROpAndTypeConversion; + + // Increments an array of subscripts in a row major fasion. + void incrementSubscripts(const SmallVector &dims, + SmallVector &subscripts) const { + for (size_t i = dims.size(); i > 0; --i) { + if (++subscripts[i - 1] < dims[i - 1]) { + return; + } + subscripts[i - 1] = 0; + } + } + + mlir::LogicalResult + doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + llvm::SmallVector dims; + auto type = adaptor.getOperands()[0].getType(); + + // Iteratively extract the array dimensions from the type. + while (auto t = type.dyn_cast()) { + dims.push_back(t.getNumElements()); + type = t.getElementType(); + } + + SmallVector lBounds; + SmallVector uBounds; + + // Extract integer value from the attribute + SmallVector coordinates = llvm::to_vector<4>( + llvm::map_range(range.coor(), [](Attribute a) -> int64_t { + return a.cast().getInt(); + })); + + // Unzip the upper and lower bound and convert to a row major format. + for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) { + uBounds.push_back(*i++); + lBounds.push_back(*i); + } + + auto &subscripts = lBounds; + auto loc = range.getLoc(); + mlir::Value lastOp = adaptor.getOperands()[0]; + mlir::Value insertVal = adaptor.getOperands()[1]; + + auto i64Ty = rewriter.getI64Type(); + while (subscripts != uBounds) { + // Convert uint64_t's to Attribute's. + SmallVector subscriptAttrs; + for (const auto &subscript : subscripts) + subscriptAttrs.push_back(IntegerAttr::get(i64Ty, subscript)); + lastOp = rewriter.create( + loc, ty, lastOp, insertVal, + ArrayAttr::get(range.getContext(), subscriptAttrs)); + + incrementSubscripts(dims, subscripts); + } + + // Convert uint64_t's to Attribute's. + SmallVector subscriptAttrs; + for (const auto &subscript : subscripts) + subscriptAttrs.push_back( + IntegerAttr::get(rewriter.getI64Type(), subscript)); + mlir::ArrayRef arrayRef(subscriptAttrs); + + rewriter.replaceOpWithNewOp( + range, ty, lastOp, insertVal, + ArrayAttr::get(range.getContext(), arrayRef)); + + return success(); + } +}; } // namespace namespace { @@ -221,10 +317,9 @@ public: auto *context = getModule().getContext(); fir::LLVMTypeConverter typeConverter{getModule()}; mlir::OwningRewritePatternList pattern(context); - pattern - .insert( - typeConverter); + pattern.insert(typeConverter); mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern); mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, pattern); diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir index 3ef31ac28818..a977dac869ea 100644 --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -84,6 +84,28 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> { // ----- +// Test global with insert_on_range operation not covering the full array +// in initializer region. + +fir.global internal @_QEmultiarray : !fir.array<32xi32> { + %c0_i32 = arith.constant 1 : i32 + %0 = fir.undefined !fir.array<32xi32> + %2 = fir.insert_on_range %0, %c0_i32, [5 : index, 31 : index] : (!fir.array<32xi32>, i32) -> !fir.array<32xi32> + fir.has_value %2 : !fir.array<32xi32> +} + +// CHECK: llvm.mlir.global internal @_QEmultiarray() : !llvm.array<32 x i32> { +// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.array<32 x i32> +// CHECK: %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[5] : !llvm.array<32 x i32> +// CHECK-COUNT-24: %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[{{.*}}] : !llvm.array<32 x i32> +// CHECK: %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[31] : !llvm.array<32 x i32> +// CHECK-NOT: llvm.insertvalue +// CHECK: llvm.return %{{.*}} : !llvm.array<32 x i32> +// CHECK: } + +// ----- + // Test fir.zero_bits operation with LLVM ptr type func @zero_test_ptr() {