From 48f77a9865c054c2318eb039d4544d3d5f0082e8 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 9 Jan 2021 17:33:38 -0800 Subject: [PATCH] [RTL] Canonicalize concat(zext(op), ...stuff) -> concat(0, op, ...stuff) --- lib/Dialect/RTL/RTLOps.cpp | 62 ++++++++++++++----- .../ExportVerilog/ExportVerilog.cpp | 3 +- test/Dialect/RTL/canonicalization.mlir | 12 ++++ 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/lib/Dialect/RTL/RTLOps.cpp b/lib/Dialect/RTL/RTLOps.cpp index 355988f5fc..02c68d35f5 100644 --- a/lib/Dialect/RTL/RTLOps.cpp +++ b/lib/Dialect/RTL/RTLOps.cpp @@ -691,14 +691,13 @@ void ConstantOp::build(OpBuilder &builder, OperationState &result, /// /// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true /// -/// If allowDuplicatingOp is true, then the 'hasOneUse' check is skipped. template -static bool tryFlatteningOperands(Op op, PatternRewriter &rewriter, - bool allowDuplicatingOp = false) { +static bool tryFlatteningOperands(Op op, PatternRewriter &rewriter) { auto inputs = op.inputs(); for (size_t i = 0, size = inputs.size(); i != size; ++i) { - if (!allowDuplicatingOp && !inputs[i].hasOneUse()) + // Don't duplicate logic. + if (!inputs[i].hasOneUse()) continue; auto flattenOp = inputs[i].template getDefiningOp(); if (!flattenOp) @@ -1298,24 +1297,55 @@ void ConcatOp::build(OpBuilder &builder, OperationState &result, build(builder, result, builder.getIntegerType(resultWidth), inputs); } +static LogicalResult tryCanonicalizeConcat(ConcatOp op, + PatternRewriter &rewriter) { + auto inputs = op.inputs(); + auto size = inputs.size(); + assert(size > 1 && "expected 2 or more operands"); + + // This function is used when we flatten neighboring operands of a (variadic) + // concat into a new vesion of the concat. first/last indices are inclusive. + auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex, + ValueRange replacements) -> LogicalResult { + SmallVector newOperands; + newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex); + newOperands.append(replacements.begin(), replacements.end()); + newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end()); + rewriter.replaceOpWithNewOp(op, op.getType(), newOperands); + return success(); + }; + + for (size_t i = 0; i != size; ++i) { + // If an operand to the concat is itself a concat, then we can fold them + // together. + if (auto subConcat = inputs[i].getDefiningOp()) + return flattenConcat(i, i, subConcat->getOperands()); + + // We can flatten a zext into the concat, since a zext is a 0 plus the input + // value. + if (auto zext = inputs[i].getDefiningOp()) { + unsigned zeroWidth = zext.getType().getIntOrFloatBitWidth() - + zext.input().getType().getIntOrFloatBitWidth(); + Value replacement[2] = { + rewriter.create(op.getLoc(), APInt(zeroWidth, 0)), + zext.input()}; + + return flattenConcat(i, i, replacement); + } + } + + /// TODO: Sequences of constants: concat(..., c1, c2) -> concat(..., c3). + /// TODO: Sequences of neighboring extracts. + return failure(); +} + void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { struct Folder final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConcatOp op, PatternRewriter &rewriter) const override { - auto inputs = op.inputs(); - auto size = inputs.size(); - assert(size > 1 && "expected 2 or more operands"); - - // concat(x, concat(...)) -> concat(x, ...) -- flatten - if (tryFlatteningOperands(op, rewriter, /*allowDuplicatingOp=*/true)) - return success(); - - /// TODO: Sequences of constants: concat(..., c1, c2) -> concat(..., c3). - /// TODO: Sequences of neighboring extracts. - /// TODO: zext argument into 0 + value. - return failure(); + return tryCanonicalizeConcat(op, rewriter); } }; results.insert(context); diff --git a/lib/Translation/ExportVerilog/ExportVerilog.cpp b/lib/Translation/ExportVerilog/ExportVerilog.cpp index 7b5feb18cc..6a416f4a6b 100644 --- a/lib/Translation/ExportVerilog/ExportVerilog.cpp +++ b/lib/Translation/ExportVerilog/ExportVerilog.cpp @@ -1963,4 +1963,5 @@ void circt::registerToVerilogTranslation() { "emit-verilog", exportVerilog, [](DialectRegistry ®istry) { registry.insert(); }); -} + +} \ No newline at end of file diff --git a/test/Dialect/RTL/canonicalization.mlir b/test/Dialect/RTL/canonicalization.mlir index 47483aaebc..0b20cb2045 100644 --- a/test/Dialect/RTL/canonicalization.mlir +++ b/test/Dialect/RTL/canonicalization.mlir @@ -538,3 +538,15 @@ func @concat_fold_1(%arg0: i4, %arg1: i3, %arg2: i1) -> i8 { %b = rtl.concat %a, %arg2 : (i7, i1) -> i8 return %b : i8 } + +// CHECK-LABEL: func @concat_fold_2 +func @concat_fold_2(%arg0: i4, %arg1: i3, %arg2: i1) -> i16 { + // Zext should get flattened into the concat + // CHECK-NEXT: %c0_i3 = rtl.constant(0 : i3) : i3 + %a = rtl.zext %arg0 : (i4) -> i7 + // CHECK-NEXT: %0 = rtl.sext + %b = rtl.sext %arg1 : (i3) -> i8 + // CHECK-NEXT: = rtl.concat %c0_i3, %arg0, %0, %arg2 + %c = rtl.concat %a, %b, %arg2 : (i7, i8, i1) -> i16 + return %c : i16 +}