[Comb] Merge ors of concats with constant operands. (#1701)

Fixes one of the patterns described in #1624, where the or'ing of two
values that only have bits set in non-overlapping positions can be
simplified to a single concat. It also addresses the more general
pattern of folding any constant operand within a concat with operands in
a sibling concat.
This commit is contained in:
Richard Xia 2021-09-10 13:28:42 -07:00 committed by GitHub
parent fb1ba5a36f
commit 826df736d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 173 additions and 0 deletions

View File

@ -891,6 +891,104 @@ OpFoldResult OrOp::fold(ArrayRef<Attribute> constants) {
constants, [](APInt &a, const APInt &b) { a |= b; });
}
/// Merge two concat operands of an or op into a single concat of narrower or
/// ops, provided at least one of the two concats has a constant operand.
///
/// The rewrite turns or(concat, concat) inside-out into concat(or, or, ...).
/// The narrower or ops are built with createOrFold, so they frequently fold
/// away entirely (e.g. when one side is a zero constant).
///
/// For example:
///
///   or(..., concat(x, 0), concat(0, y))
///     ==> or(..., concat(x, 0, y)), when x and y don't overlap.
///
///   or(..., concat(x: i2, cst1: i4), concat(cst2: i5, y: i1))
///     ==> or(..., concat(or(x: i2,               extract(cst2, 4..3)),
///                        or(extract(cst1, 3..1), extract(cst2, 2..0)),
///                        or(extract(cst1, 0..0), y: i1))
///
/// `concatIdx1` and `concatIdx2` index into the or op's inputs; both must refer
/// to values defined by a ConcatOp and `concatIdx1` must be the smaller index.
///
/// Returns false, leaving the IR untouched, when neither concat has a constant
/// operand. Otherwise replaces `op` with a new or op whose inputs are the
/// remaining original inputs followed by the merged concat, and returns true.
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1,
                                                   size_t concatIdx2,
                                                   PatternRewriter &rewriter) {
  assert(concatIdx1 < concatIdx2 && "concatIdx1 must be < concatIdx2");

  auto inputs = op.inputs();
  auto concat1 = inputs[concatIdx1].getDefiningOp<ConcatOp>();
  auto concat2 = inputs[concatIdx2].getDefiningOp<ConcatOp>();
  assert(concat1 && concat2 && "expected indexes to point to ConcatOps");

  // The rewrite is applied only if one of the concats carries a constant.
  auto isConstant = [](Value operand) -> bool {
    return operand.getDefiningOp<hw::ConstantOp>();
  };
  bool lhsHasConstant = llvm::any_of(concat1->getOperands(), isConstant);
  bool rhsHasConstant = llvm::any_of(concat2->getOperands(), isConstant);
  if (!(lhsHasConstant || rhsHasConstant))
    return false;

  auto loc = op.getLoc();
  auto operands1 = concat1->getOperands();
  auto operands2 = concat2->getOperands();

  // Walk both operand lists in lock-step from MSB to LSB. At every step the
  // widest slice that lies within the current operand of *both* concats is
  // peeled off each side, and the or of the two slices becomes the next
  // operand of the merged concat.
  SmallVector<Value> mergedOperands;
  auto it1 = operands1.begin();
  auto end1 = operands1.end();
  auto it2 = operands2.begin();
  auto end2 = operands2.end();
  // Bits of the current operand of each concat that earlier steps already
  // covered.
  unsigned usedBits1 = 0;
  unsigned usedBits2 = 0;
  while (it1 != end1 && it2 != end2) {
    Value lhs = *it1;
    Value rhs = *it2;
    unsigned leftBits1 = hw::getBitWidth(lhs.getType()) - usedBits1;
    unsigned leftBits2 = hw::getBitWidth(rhs.getType()) - usedBits2;
    unsigned sliceBits = std::min(leftBits1, leftBits2);
    auto sliceType = rewriter.getIntegerType(sliceBits);

    // The slice is the top `sliceBits` of whatever is left in each operand,
    // so its low bit sits at (left - slice).
    auto lhsSlice = rewriter.createOrFold<ExtractOp>(loc, sliceType, lhs,
                                                     leftBits1 - sliceBits);
    auto rhsSlice = rewriter.createOrFold<ExtractOp>(loc, sliceType, rhs,
                                                     leftBits2 - sliceBits);
    mergedOperands.push_back(
        rewriter.createOrFold<OrOp>(loc, lhsSlice, rhsSlice));

    // Step to the next operand on whichever side(s) were fully consumed;
    // otherwise remember how far into the current operand we are.
    if (sliceBits == leftBits1) {
      ++it1;
      usedBits1 = 0;
    } else {
      usedBits1 += sliceBits;
    }
    if (sliceBits == leftBits2) {
      ++it2;
      usedBits2 = 0;
    } else {
      usedBits2 += sliceBits;
    }
  }
  ConcatOp mergedConcat = rewriter.create<ConcatOp>(loc, mergedOperands);

  // Keep every original input except the two concats (in their original
  // order) and put the merged concat last.
  SmallVector<Value> newOrOperands;
  for (size_t idx = 0, numInputs = inputs.size(); idx != numInputs; ++idx) {
    if (idx == concatIdx1 || idx == concatIdx2)
      continue;
    newOrOperands.push_back(inputs[idx]);
  }
  newOrOperands.push_back(mergedConcat);

  rewriter.replaceOpWithNewOp<OrOp>(op, op.getType(), newOrOperands);
  return true;
}
LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
auto inputs = op.inputs();
auto size = inputs.size();
@ -935,6 +1033,16 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
if (tryFlatteningOperands(op, rewriter))
return success();
// or(..., concat(x, cst1), concat(cst2, y))
//    ==> or(..., concat(x, cst3, y)), when x and y don't overlap.
//
// Try every pair (i, j), i < j, of inputs that are both defined by a ConcatOp.
// The helper rewrites the op only when one of the two concats has a constant
// operand; on the first successful rewrite we are done.
//
// The bound is written `i + 1 < size` rather than `i < size - 1` so that it
// cannot wrap around if the unsigned `size` were ever zero.
for (size_t i = 0; i + 1 < size; ++i) {
  if (!inputs[i].getDefiningOp<ConcatOp>())
    continue;
  for (size_t j = i + 1; j < size; ++j)
    if (inputs[j].getDefiningOp<ConcatOp>() &&
        canonicalizeOrOfConcatsWithCstOperands(op, i, j, rewriter))
      return success();
}
/// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
return failure();
}

View File

@ -67,6 +67,71 @@ hw.module @andDedupLong(%arg0: i7, %arg1: i7, %arg2: i7) -> (i7) {
hw.output %0 : i7
}
// Two concats whose non-constant fields occupy disjoint bit positions, with
// zero constants everywhere else, are or'd together. The expected output is a
// single concat with no remaining or.
// (a = %arg0, b = %arg1)
// %0: 000a aaaaa
// %1: bb00 00000
// merged: bb0a aaaaa
// CHECK-LABEL: hw.module @orExclusiveConcats
hw.module @orExclusiveConcats(%arg0: i6, %arg1: i2) -> (%o: i9) {
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %0 = comb.concat %arg1, %false, %arg0 : (i2, i1, i6) -> i9
// CHECK-NEXT: hw.output %0 : i9
%c0 = hw.constant 0 : i3
%0 = comb.concat %c0, %arg0 : (i3, i6) -> i9
%c1 = hw.constant 0 : i7
%1 = comb.concat %arg1, %c1 : (i2, i7) -> i9
%2 = comb.or %0, %1 : i9
hw.output %2 : i9
}
// When two concats are or'd together and have mutually-exclusive fields, they
// can be merged together into a single concat.
// In the diagram below: a = %arg0, b = %arg1, c = %arg2, d = %arg3, and every
// 0 comes from a zero constant.
// concat0: 0aaa aaa0 0000 0bb0
// concat1: 0000 0000 ccdd d000
// merged: 0aaa aaa0 ccdd dbb0
// CHECK-LABEL: hw.module @orExclusiveConcats2
hw.module @orExclusiveConcats2(%arg0: i6, %arg1: i2, %arg2: i2, %arg3: i3) -> (%o: i16) {
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %0 = comb.concat %false, %arg0, %false, %arg2, %arg3, %arg1, %false : (i1, i6, i1, i2, i3, i2, i1) -> i16
// CHECK-NEXT: hw.output %0 : i16
%c0 = hw.constant 0 : i1
%c1 = hw.constant 0 : i6
%c2 = hw.constant 0 : i1
%0 = comb.concat %c0, %arg0, %c1, %arg1, %c2: (i1, i6, i6, i2, i1) -> i16
%c3 = hw.constant 0 : i8
%c4 = hw.constant 0 : i3
%1 = comb.concat %c3, %arg2, %arg3, %c4 : (i8, i2, i3, i3) -> i16
%2 = comb.or %0, %1 : i16
hw.output %2 : i16
}
// When two concats are or'd together and every bit position is covered by a
// one from a constant operand of at least one of them, the non-constant
// fields are masked out and the whole or folds to an all-ones constant.
// In the diagram below: a = %arg0, b = %arg1.
// concat0: aaaa 1111
// concat1: 1111 10bb
// merged: 1111 1111
// CHECK-LABEL: hw.module @orExclusiveConcats3
hw.module @orExclusiveConcats3(%arg0: i4, %arg1: i2) -> (%o: i8) {
// CHECK-NEXT: [[RES:%[a-z0-9_-]+]] = hw.constant -1 : i8
// CHECK-NEXT: hw.output [[RES]] : i8
%c0 = hw.constant -1 : i4
%0 = comb.concat %arg0, %c0: (i4, i4) -> i8
%c1 = hw.constant -1 : i5
%c2 = hw.constant 0 : i1
%1 = comb.concat %c1, %c2, %arg1 : (i5, i1, i2) -> i8
%2 = comb.or %0, %1 : i8
hw.output %2 : i8
}
// An or with three concat operands, each contributing one non-constant field
// in a distinct position and zero constants elsewhere. The expected output is
// a single concat of the three fields.
// (a = %arg0, b = %arg1, c = %arg2)
// %0: aa 00 00
// %1: 00 bb 00
// %2: 00 00 cc
// merged: aa bb cc
// CHECK-LABEL: hw.module @orMultipleExclusiveConcats
hw.module @orMultipleExclusiveConcats(%arg0: i2, %arg1: i2, %arg2: i2) -> (%o: i6) {
// CHECK-NEXT: %0 = comb.concat %arg0, %arg1, %arg2 : (i2, i2, i2) -> i6
// CHECK-NEXT: hw.output %0 : i6
%c2 = hw.constant 0 : i2
%c4 = hw.constant 0 : i4
%0 = comb.concat %arg0, %c4: (i2, i4) -> i6
%1 = comb.concat %c2, %arg1, %c2: (i2, i2, i2) -> i6
%2 = comb.concat %c4, %arg2: (i4, i2) -> i6
%out = comb.or %0, %1, %2 : i6
hw.output %out : i6
}
// CHECK-LABEL: @extractNested
hw.module @extractNested(%0: i5) -> (%o1 : i1) {
// Multiple layers of nested extract is weak evidence that the canonicalization