mirror of https://github.com/llvm/circt.git
[RTL] Canonicalize concat(zext(op), ...stuff) -> concat(0, op, ...stuff)
This commit is contained in:
parent
fa05e84822
commit
48f77a9865
|
@ -691,14 +691,13 @@ void ConstantOp::build(OpBuilder &builder, OperationState &result,
|
|||
///
|
||||
/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
|
||||
///
|
||||
/// If allowDuplicatingOp is true, then the 'hasOneUse' check is skipped.
|
||||
template <typename Op>
|
||||
static bool tryFlatteningOperands(Op op, PatternRewriter &rewriter,
|
||||
bool allowDuplicatingOp = false) {
|
||||
static bool tryFlatteningOperands(Op op, PatternRewriter &rewriter) {
|
||||
auto inputs = op.inputs();
|
||||
|
||||
for (size_t i = 0, size = inputs.size(); i != size; ++i) {
|
||||
if (!allowDuplicatingOp && !inputs[i].hasOneUse())
|
||||
// Don't duplicate logic.
|
||||
if (!inputs[i].hasOneUse())
|
||||
continue;
|
||||
auto flattenOp = inputs[i].template getDefiningOp<Op>();
|
||||
if (!flattenOp)
|
||||
|
@ -1298,24 +1297,55 @@ void ConcatOp::build(OpBuilder &builder, OperationState &result,
|
|||
build(builder, result, builder.getIntegerType(resultWidth), inputs);
|
||||
}
|
||||
|
||||
static LogicalResult tryCanonicalizeConcat(ConcatOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
auto inputs = op.inputs();
|
||||
auto size = inputs.size();
|
||||
assert(size > 1 && "expected 2 or more operands");
|
||||
|
||||
// This function is used when we flatten neighboring operands of a (variadic)
|
||||
// concat into a new vesion of the concat. first/last indices are inclusive.
|
||||
auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
|
||||
ValueRange replacements) -> LogicalResult {
|
||||
SmallVector<Value, 4> newOperands;
|
||||
newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
|
||||
newOperands.append(replacements.begin(), replacements.end());
|
||||
newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
|
||||
rewriter.replaceOpWithNewOp<ConcatOp>(op, op.getType(), newOperands);
|
||||
return success();
|
||||
};
|
||||
|
||||
for (size_t i = 0; i != size; ++i) {
|
||||
// If an operand to the concat is itself a concat, then we can fold them
|
||||
// together.
|
||||
if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
|
||||
return flattenConcat(i, i, subConcat->getOperands());
|
||||
|
||||
// We can flatten a zext into the concat, since a zext is a 0 plus the input
|
||||
// value.
|
||||
if (auto zext = inputs[i].getDefiningOp<ZExtOp>()) {
|
||||
unsigned zeroWidth = zext.getType().getIntOrFloatBitWidth() -
|
||||
zext.input().getType().getIntOrFloatBitWidth();
|
||||
Value replacement[2] = {
|
||||
rewriter.create<ConstantOp>(op.getLoc(), APInt(zeroWidth, 0)),
|
||||
zext.input()};
|
||||
|
||||
return flattenConcat(i, i, replacement);
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO: Sequences of constants: concat(..., c1, c2) -> concat(..., c3).
|
||||
/// TODO: Sequences of neighboring extracts.
|
||||
return failure();
|
||||
}
|
||||
|
||||
void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
struct Folder final : public OpRewritePattern<ConcatOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ConcatOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto inputs = op.inputs();
|
||||
auto size = inputs.size();
|
||||
assert(size > 1 && "expected 2 or more operands");
|
||||
|
||||
// concat(x, concat(...)) -> concat(x, ...) -- flatten
|
||||
if (tryFlatteningOperands(op, rewriter, /*allowDuplicatingOp=*/true))
|
||||
return success();
|
||||
|
||||
/// TODO: Sequences of constants: concat(..., c1, c2) -> concat(..., c3).
|
||||
/// TODO: Sequences of neighboring extracts.
|
||||
/// TODO: zext argument into 0 + value.
|
||||
return failure();
|
||||
return tryCanonicalizeConcat(op, rewriter);
|
||||
}
|
||||
};
|
||||
results.insert<Folder>(context);
|
||||
|
|
|
@ -1963,4 +1963,5 @@ void circt::registerToVerilogTranslation() {
|
|||
"emit-verilog", exportVerilog, [](DialectRegistry ®istry) {
|
||||
registry.insert<RTLDialect, SVDialect>();
|
||||
});
|
||||
}
|
||||
|
||||
}
|
|
@ -538,3 +538,15 @@ func @concat_fold_1(%arg0: i4, %arg1: i3, %arg2: i1) -> i8 {
|
|||
%b = rtl.concat %a, %arg2 : (i7, i1) -> i8
|
||||
return %b : i8
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @concat_fold_2
|
||||
func @concat_fold_2(%arg0: i4, %arg1: i3, %arg2: i1) -> i16 {
|
||||
// Zext should get flattened into the concat
|
||||
// CHECK-NEXT: %c0_i3 = rtl.constant(0 : i3) : i3
|
||||
%a = rtl.zext %arg0 : (i4) -> i7
|
||||
// CHECK-NEXT: %0 = rtl.sext
|
||||
%b = rtl.sext %arg1 : (i3) -> i8
|
||||
// CHECK-NEXT: = rtl.concat %c0_i3, %arg0, %0, %arg2
|
||||
%c = rtl.concat %a, %b, %arg2 : (i7, i8, i1) -> i16
|
||||
return %c : i16
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue