[NFC] Convert a canonicalizer to a pattern

This commit is contained in:
Andrew Lenharth 2024-05-17 13:34:40 -05:00
parent 5e2ad89fda
commit a1a3b6ca78
2 changed files with 11 additions and 29 deletions

View File

@ -120,6 +120,9 @@ def AllOneConstantOp : Constraint<CPred<"$0.getDefiningOp<ConstantOp>() && $0.ge
def TypeWidthAdjust32 : NativeCodeCall<
"$_builder.getI32IntegerAttr(type_cast<FIRRTLBaseType>($0.getType()).getBitWidthOrSentinel() - type_cast<FIRRTLBaseType>($1.getType()).getBitWidthOrSentinel())">;
// Int 1 is one higher than int 2
def AdjInt : Constraint<CPred<"$0.getValue() + 1 == $1.getValue()">>;
/// Drop the writer to the first argument and passthrough the second
def DropWrite : NativeCodeCall<"dropWrite($_builder, $0, $1)">;
@ -693,6 +696,12 @@ def CatCast : Pat <
(MoveNameHint $old, (CatPrimOp $x, $y)),
[(EqualSigns $x, $y)]>;
// cat(bits a:b x, bits b+1:c x) -> bits( a:c x)
def CatBitsBits : Pat <
(CatPrimOp:$old (BitsPrimOp $x, I32Attr:$hi, I32Attr:$mid), (BitsPrimOp $x, I32Attr:$mid2, I32Attr:$low)),
(MoveNameHint $old, (BitsPrimOp $x, $hi, $low)),
[(AdjInt $mid2, $mid)]>;
// regreset(clock, constant_zero, resetValue) -> reg(clock)
def RegResetWithZeroReset : Pat<
(RegResetOp $clock, $reset, $_, $name, $nameKind, $annotations, $inner_sym, $forceable),

View File

@ -1199,37 +1199,10 @@ void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.insert<patterns::DShrOfConstant>(context);
}
namespace {
// cat(bits(x, ...), bits(x, ...)) -> bits(x ...) when the two ...'s are
// consecutive in the input.
struct CatBitsBits : public mlir::RewritePattern {
CatBitsBits(MLIRContext *context)
: RewritePattern(CatPrimOp::getOperationName(), 0, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto cat = cast<CatPrimOp>(op);
if (auto lhsBits =
dyn_cast_or_null<BitsPrimOp>(cat.getLhs().getDefiningOp())) {
if (auto rhsBits =
dyn_cast_or_null<BitsPrimOp>(cat.getRhs().getDefiningOp())) {
if (lhsBits.getInput() == rhsBits.getInput() &&
lhsBits.getLo() - 1 == rhsBits.getHi()) {
replaceOpWithNewOpAndCopyName<BitsPrimOp>(
rewriter, cat, cat.getType(), lhsBits.getInput(), lhsBits.getHi(),
rhsBits.getLo());
return success();
}
}
}
return failure();
}
};
} // namespace
void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<CatBitsBits, patterns::CatDoubleConst, patterns::CatCast>(
context);
results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
patterns::CatCast>(context);
}
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {