[RTL] Canonicalize concat(zext(op), ...stuff) -> concat(0, op, ...stuff)

This commit is contained in:
Chris Lattner 2021-01-09 17:33:38 -08:00
parent fa05e84822
commit 48f77a9865
3 changed files with 60 additions and 17 deletions

View File

@ -691,14 +691,13 @@ void ConstantOp::build(OpBuilder &builder, OperationState &result,
/// ///
/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true /// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
/// ///
/// If allowDuplicatingOp is true, then the 'hasOneUse' check is skipped.
template <typename Op> template <typename Op>
static bool tryFlatteningOperands(Op op, PatternRewriter &rewriter, static bool tryFlatteningOperands(Op op, PatternRewriter &rewriter) {
bool allowDuplicatingOp = false) {
auto inputs = op.inputs(); auto inputs = op.inputs();
for (size_t i = 0, size = inputs.size(); i != size; ++i) { for (size_t i = 0, size = inputs.size(); i != size; ++i) {
if (!allowDuplicatingOp && !inputs[i].hasOneUse()) // Don't duplicate logic.
if (!inputs[i].hasOneUse())
continue; continue;
auto flattenOp = inputs[i].template getDefiningOp<Op>(); auto flattenOp = inputs[i].template getDefiningOp<Op>();
if (!flattenOp) if (!flattenOp)
@ -1298,24 +1297,55 @@ void ConcatOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, builder.getIntegerType(resultWidth), inputs); build(builder, result, builder.getIntegerType(resultWidth), inputs);
} }
static LogicalResult tryCanonicalizeConcat(ConcatOp op,
PatternRewriter &rewriter) {
auto inputs = op.inputs();
auto size = inputs.size();
assert(size > 1 && "expected 2 or more operands");
// This function is used when we flatten neighboring operands of a (variadic)
// concat into a new vesion of the concat. first/last indices are inclusive.
auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
ValueRange replacements) -> LogicalResult {
SmallVector<Value, 4> newOperands;
newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
newOperands.append(replacements.begin(), replacements.end());
newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
rewriter.replaceOpWithNewOp<ConcatOp>(op, op.getType(), newOperands);
return success();
};
for (size_t i = 0; i != size; ++i) {
// If an operand to the concat is itself a concat, then we can fold them
// together.
if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
return flattenConcat(i, i, subConcat->getOperands());
// We can flatten a zext into the concat, since a zext is a 0 plus the input
// value.
if (auto zext = inputs[i].getDefiningOp<ZExtOp>()) {
unsigned zeroWidth = zext.getType().getIntOrFloatBitWidth() -
zext.input().getType().getIntOrFloatBitWidth();
Value replacement[2] = {
rewriter.create<ConstantOp>(op.getLoc(), APInt(zeroWidth, 0)),
zext.input()};
return flattenConcat(i, i, replacement);
}
}
/// TODO: Sequences of constants: concat(..., c1, c2) -> concat(..., c3).
/// TODO: Sequences of neighboring extracts.
return failure();
}
void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) { MLIRContext *context) {
struct Folder final : public OpRewritePattern<ConcatOp> { struct Folder final : public OpRewritePattern<ConcatOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatOp op, LogicalResult matchAndRewrite(ConcatOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto inputs = op.inputs(); return tryCanonicalizeConcat(op, rewriter);
auto size = inputs.size();
assert(size > 1 && "expected 2 or more operands");
// concat(x, concat(...)) -> concat(x, ...) -- flatten
if (tryFlatteningOperands(op, rewriter, /*allowDuplicatingOp=*/true))
return success();
/// TODO: Sequences of constants: concat(..., c1, c2) -> concat(..., c3).
/// TODO: Sequences of neighboring extracts.
/// TODO: zext argument into 0 + value.
return failure();
} }
}; };
results.insert<Folder>(context); results.insert<Folder>(context);

View File

@ -1963,4 +1963,5 @@ void circt::registerToVerilogTranslation() {
"emit-verilog", exportVerilog, [](DialectRegistry &registry) { "emit-verilog", exportVerilog, [](DialectRegistry &registry) {
registry.insert<RTLDialect, SVDialect>(); registry.insert<RTLDialect, SVDialect>();
}); });
}
}

View File

@ -538,3 +538,15 @@ func @concat_fold_1(%arg0: i4, %arg1: i3, %arg2: i1) -> i8 {
%b = rtl.concat %a, %arg2 : (i7, i1) -> i8 %b = rtl.concat %a, %arg2 : (i7, i1) -> i8
return %b : i8 return %b : i8
} }
// CHECK-LABEL: func @concat_fold_2
func @concat_fold_2(%arg0: i4, %arg1: i3, %arg2: i1) -> i16 {
// Zext should get flattened into the concat
// CHECK-NEXT: %c0_i3 = rtl.constant(0 : i3) : i3
%a = rtl.zext %arg0 : (i4) -> i7
// CHECK-NEXT: %0 = rtl.sext
%b = rtl.sext %arg1 : (i3) -> i8
// CHECK-NEXT: = rtl.concat %c0_i3, %arg0, %0, %arg2
%c = rtl.concat %a, %b, %arg2 : (i7, i8, i1) -> i16
return %c : i16
}