[Comb] Removed range merging (#2943)

Due to the semantics of 'x' values, range merging is incorrect and it is hereby banished from the Comb dialect.
This commit is contained in:
Nandor Licker 2022-04-21 19:11:44 +03:00 committed by GitHub
parent de9b60af9d
commit e4cebcf54c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 0 additions and 463 deletions

View File

@ -1000,217 +1000,6 @@ static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1,
return true;
}
/// Identifies integer range checks and simplifies them to minimise comparisons.
///
/// In an 'or' operation, tests of a variable belonging to a set of disjoint
/// intervals are identified and merged whenever possible. 3 distinct patterns
/// are considered to identify the intervals:
///
/// - icmp(ult, x, n), constrains x < n
/// - icmp(eq, x, n), constrains x = n
/// - and(icmp(ugt, x, n), icmp(ult, x, m)) constrains n < x < m
///
/// For an individual value x, all the comparisons are collected and new
/// comparisons are inserted to represent the intervals with fewer operations.
///
/// If a sequence of equality checks or(x == n, x == n + 1, ...) is detected,
/// an interval is emitted only when at least 3 equality checks are folded into
/// two unsigned comparisons.
///
/// All constants and operations operate in unsigned mode.
///
static bool tryMergeRanges(OrOp op, PatternRewriter &rewriter) {
// Wrapper around an operand representing an interval.
struct Interval {
/// Index of the input from which the check was extracted.
unsigned index;
/// Inclusive lower bound.
APInt lowerBound;
/// Inclusive upper bound.
APInt upperBound;
/// Returns true if this is an equality check.
bool isEqCheck() const { return lowerBound == upperBound; }
};
// Identify all the relevant patterns among the inputs,
// mark all others to be retained unchanged.
auto inputs = op.inputs();
llvm::SmallBitVector keptOperands(inputs.size());
DenseMap<Value, SmallVector<Interval>> argChecks;
for (unsigned i = 0, n = inputs.size(); i < n; ++i) {
auto input = inputs[i];
// Find bound checks: and(x > n, x < m)
auto andOp = input.getDefiningOp<AndOp>();
if (andOp && andOp.inputs().size() == 2) {
ICmpOp lhsOp = andOp.inputs()[0].getDefiningOp<ICmpOp>();
ICmpOp rhsOp = andOp.inputs()[1].getDefiningOp<ICmpOp>();
if (lhsOp && rhsOp && lhsOp.lhs() == rhsOp.lhs()) {
Value arg = lhsOp.lhs();
APInt lhsBound, rhsBound;
if (matchPattern(lhsOp.rhs(), m_RConstant(lhsBound)) &&
matchPattern(rhsOp.rhs(), m_RConstant(rhsBound))) {
ICmpPredicate lhsPred = lhsOp.predicate();
ICmpPredicate rhsPred = rhsOp.predicate();
// If the lower bound is all ones or the upper bound is all
// zeros, the interval is empty and the comparisons are eliminated
// through other canonicalisers.
if (lhsPred == ICmpPredicate::ugt && rhsPred == ICmpPredicate::ult) {
if (!lhsBound.isAllOnes() && !rhsBound.isZero()) {
argChecks[arg].emplace_back(
Interval{i, lhsBound + 1, rhsBound - 1});
}
continue;
}
if (lhsPred == ICmpPredicate::ult && rhsPred == ICmpPredicate::ugt) {
if (!rhsBound.isAllOnes() && !lhsBound.isZero()) {
argChecks[arg].emplace_back(
Interval{i, rhsBound + 1, lhsBound - 1});
}
continue;
}
}
}
}
if (auto cmpOp = input.getDefiningOp<ICmpOp>()) {
APInt v;
if (matchPattern(cmpOp.rhs(), m_RConstant(v))) {
// Find equality tests: x == n
if (cmpOp.predicate() == ICmpPredicate::eq) {
argChecks[cmpOp.lhs()].emplace_back(Interval{i, v, v});
continue;
}
// Find upper bound tests: x < n
if (cmpOp.predicate() == ICmpPredicate::ult) {
if (!v.isZero()) {
argChecks[cmpOp.lhs()].emplace_back(
Interval{i, APInt::getZero(v.getBitWidth()), v - 1});
}
continue;
}
}
}
keptOperands[i] = true;
}
// For each value, try to compress the associated checks.
using RangeCheck = std::tuple<Location, Optional<APInt>, Optional<APInt>>;
llvm::DenseMap<Value, SmallVector<RangeCheck>> newChecks;
bool foldsToTrue = false;
for (auto &[arg, checks] : argChecks) {
// Order the checks by their lower bounds.
std::stable_sort(checks.begin(), checks.end(), [&](auto &lhs, auto &rhs) {
return lhs.lowerBound.ult(rhs.lowerBound);
});
// Find sequences of overlapping checks and compress them.
for (auto *it = checks.begin(); it != checks.end();) {
auto *begin = it++;
auto lowerBound = begin->lowerBound;
APInt upperBound = begin->upperBound;
while (it != checks.end()) {
if (!upperBound.isAllOnes() && !it->lowerBound.ule(upperBound + 1))
break;
APInt itBound = it->upperBound;
if (itBound.ugt(upperBound))
upperBound = itBound;
++it;
}
// If there is no overlap or the range consists of only two
// consecutive numbers, do not create a new range.
size_t n = std::distance(begin, it);
if (n == 1 ||
(n == 2 && begin->isEqCheck() && std::next(begin)->isEqCheck())) {
while (begin != it) {
keptOperands[begin->index] = true;
++begin;
}
continue;
}
// Fuse the locations of all the operands.
llvm::SmallVector<Location> locations;
for (auto *op = begin; op != it; ++op)
locations.push_back(inputs[op->index].getLoc());
// Build the check for the upper bound.
LocationAttr loc = rewriter.getFusedLoc(locations);
if (lowerBound.isZero() && upperBound.isAllOnes())
foldsToTrue = true;
else if (lowerBound.isZero())
newChecks[arg].emplace_back(loc, llvm::None, upperBound + 1);
else if (upperBound.isAllOnes())
newChecks[arg].emplace_back(loc, lowerBound - 1, llvm::None);
else
newChecks[arg].emplace_back(loc, lowerBound - 1, upperBound + 1);
}
}
// If any of the arguments had checks exhaustively covering its
// range, fold the entire or op to true.
if (foldsToTrue) {
rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, rewriter.getI1Type(), true);
return true;
}
// Do not change the op if no checks were rewritten or no
// redundant conditions were eliminated.
if (newChecks.empty() && inputs.size() == keptOperands.count())
return false;
SmallVector<Value> newInputs;
for (unsigned i = 0, n = inputs.size(); i < n; ++i)
if (keptOperands[i])
newInputs.push_back(inputs[i]);
// Build range checks for all the checks on all arguments.
for (auto &[arg, checks] : newChecks) {
for (auto &[loc, lowerBound, upperBound] : checks) {
Value upperBoundCheck;
if (upperBound) {
upperBoundCheck = rewriter.create<ICmpOp>(
loc, rewriter.getI1Type(), ICmpPredicate::ult, arg,
rewriter.create<hw::ConstantOp>(loc, *upperBound));
}
Value lowerBoundCheck;
if (lowerBound) {
lowerBoundCheck = rewriter.create<ICmpOp>(
loc, rewriter.getI1Type(), ICmpPredicate::ugt, arg,
rewriter.create<hw::ConstantOp>(loc, *lowerBound));
}
if (lowerBoundCheck && upperBoundCheck)
newInputs.push_back(
rewriter.create<AndOp>(loc, lowerBoundCheck, upperBoundCheck));
else if (lowerBoundCheck)
newInputs.push_back(lowerBoundCheck);
else if (upperBoundCheck)
newInputs.push_back(upperBoundCheck);
else
llvm_unreachable("empty range");
}
}
if (newInputs.size() == 1) {
rewriter.replaceOp(op, newInputs[0]);
} else {
assert(newInputs.size() > 1 && "OrOp expects at least two inputs");
rewriter.replaceOpWithNewOp<OrOp>(op, op.getType(), newInputs);
}
return true;
}
LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
auto inputs = op.inputs();
auto size = inputs.size();
@ -1265,12 +1054,6 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
return success();
}
// or(eq(x, n), eq(x, n + 1), ...) -> or(and(n <= x, x <= n + ...), ...)
// or(and(a < x, x < b), and(c < x, x < d)) -> or(and(a <= x, x < c)) if c < b
// or(and(a < x, x < b), x = b) -> or(and(a < x, x < b + 1))
if (tryMergeRanges(op, rewriter))
return success();
// extracts only of or(...) -> or(extract()...)
if (narrowOperationWidth(op, true, rewriter))
return success();

View File

@ -1,246 +0,0 @@
// RUN: circt-opt -canonicalize='top-down=true region-simplify=true' %s | FileCheck %s
// CHECK-LABEL: @collapse_or
// CHECK-DAG: [[START_A:%.*]] = hw.constant 99 : i32
// CHECK-DAG: [[END_A:%.*]] = hw.constant 106 : i32
// CHECK-DAG: [[CHECK_START_A:%.*]] = comb.icmp ugt %arg, [[START_A]] : i32
// CHECK-DAG: [[CHECK_END_A:%.*]] = comb.icmp ult %arg, [[END_A]] : i32
// CHECK-DAG: [[START_B:%.*]] = hw.constant 1009 : i32
// CHECK-DAG: [[END_B:%.*]] = hw.constant 1015 : i32
// CHECK-DAG: [[CHECK_START_B:%.*]] = comb.icmp ugt %arg, [[START_B]] : i32
// CHECK-DAG: [[CHECK_END_B:%.*]] = comb.icmp ult %arg, [[END_B]] : i32
// CHECK-DAG: [[RANGE_B:%.*]] = comb.and [[CHECK_START_B]], [[CHECK_END_B]] : i1
// CHECK-DAG: [[RANGE_A:%.*]] = comb.and [[CHECK_START_A]], [[CHECK_END_A]] : i1
// CHECK-DAG: [[RESULT:%.*]] = comb.or [[RANGE_A]], [[RANGE_B]] : i1
// CHECK-DAG: hw.output [[RESULT]] : i1
hw.module @collapse_or(%arg: i32) -> (cond: i1) {
%cst0 = hw.constant 100 : i32
%is0 = comb.icmp eq %cst0, %arg : i32
%cst1 = hw.constant 101 : i32
%is1 = comb.icmp eq %cst1, %arg : i32
%cst2 = hw.constant 102 : i32
%is2 = comb.icmp eq %arg, %cst2 : i32
%cst3 = hw.constant 103 : i32
%is3 = comb.icmp eq %arg, %cst3 : i32
%cst4 = hw.constant 104 : i32
%is4 = comb.icmp eq %arg, %cst4 : i32
%cst5 = hw.constant 105 : i32
%is5 = comb.icmp eq %arg, %cst5 : i32
%cst10 = hw.constant 1010 : i32
%is10 = comb.icmp eq %cst10, %arg : i32
%cst11 = hw.constant 1011 : i32
%is11 = comb.icmp eq %arg, %cst11 : i32
%cst12 = hw.constant 1012 : i32
%is12 = comb.icmp eq %arg, %cst12 : i32
%cst13 = hw.constant 1013 : i32
%is13 = comb.icmp eq %arg, %cst13 : i32
%cst14 = hw.constant 1014 : i32
%is14 = comb.icmp eq %cst14, %arg : i32
%in_range = comb.or %is0, %is1, %is3, %is4, %is2, %is5, %is11, %is10, %is12, %is13, %is14 : i1
hw.output %in_range : i1
}
// CHECK-LABEL: collapse_or_chain
// CHECK-DAG: [[START:%.*]] = hw.constant 106 : i32
// CHECK-DAG: [[END:%.*]] = hw.constant 99 : i32
// CHECK-DAG: [[CHECK_START:%.*]] = comb.icmp ugt %arg, [[END]] : i32
// CHECK-DAG: [[CHECK_END:%.*]] = comb.icmp ult %arg, [[START]] : i32
// CHECK-DAG: [[RESULT:%.*]] = comb.and [[CHECK_START]], [[CHECK_END]] : i1
// CHECK-DAG: hw.output [[RESULT]] : i1
hw.module @collapse_or_chain(%arg: i32) -> (cond: i1) {
%cst0 = hw.constant 100 : i32
%is0 = comb.icmp eq %cst0, %arg : i32
%cst1 = hw.constant 101 : i32
%is1 = comb.icmp eq %cst1, %arg : i32
%cst2 = hw.constant 102 : i32
%is2 = comb.icmp eq %arg, %cst2 : i32
%cst3 = hw.constant 103 : i32
%is3 = comb.icmp eq %arg, %cst3 : i32
%cst4 = hw.constant 104 : i32
%is4 = comb.icmp eq %arg, %cst4 : i32
%cst5 = hw.constant 105 : i32
%is5 = comb.icmp eq %arg, %cst5 : i32
%is0_1 = comb.or %is0, %is1 : i1
%is2_3 = comb.or %is2, %is3 : i1
%is4_5 = comb.or %is4, %is5 : i1
%is0_3 = comb.or %is0_1, %is2_3 : i1
%is0_5 = comb.or %is0_3, %is4_5 : i1
hw.output %is0_5 : i1
}
// CHECK-LABEL: @merge_ranges
// CHECK-DAG: [[START:%.*]] = hw.constant 300 : i32
// CHECK-DAG: [[END:%.*]] = hw.constant 100 : i32
// CHECK-DAG: [[CHECK_START:%.*]] = comb.icmp ugt %arg, [[END]] : i32
// CHECK-DAG: [[CHECK_END:%.*]] = comb.icmp ult %arg, [[START]] : i32
// CHECK-DAG: [[RESULT:%.*]] = comb.and [[CHECK_START]], [[CHECK_END]] : i1
// CHECK-DAG: hw.output [[RESULT]] : i1
hw.module @merge_ranges(%arg: i32) -> (cond: i1) {
%start_a = hw.constant 100 : i32
%end_a = hw.constant 200 : i32
%check_start_a = comb.icmp ugt %arg, %start_a : i32
%check_end_a = comb.icmp ult %arg, %end_a : i32
%in_a = comb.and %check_start_a, %check_end_a : i1
%start_b = hw.constant 199 : i32
%end_b = hw.constant 300 : i32
%check_start_b = comb.icmp ugt %arg, %start_b : i32
%check_end_b = comb.icmp ult %arg, %end_b : i32
%in_b = comb.and %check_end_b, %check_start_b : i1
%in_range = comb.or %in_a, %in_b : i1
hw.output %in_range : i1
}
// CHECK-LABEL: @extend_ranges
// CHECK-DAG: [[START:%.*]] = hw.constant 201 : i32
// CHECK-DAG: [[END:%.*]] = hw.constant 100 : i32
// CHECK-DAG: [[CHECK_START:%.*]] = comb.icmp ugt %arg, [[END]] : i32
// CHECK-DAG: [[CHECK_END:%.*]] = comb.icmp ult %arg, [[START]] : i32
// CHECK-DAG: [[RESULT:%.*]] = comb.and [[CHECK_START]], [[CHECK_END]] : i1
// CHECK-DAG: hw.output [[RESULT]] : i1
hw.module @extend_ranges(%arg: i32) -> (cond: i1) {
%start_a = hw.constant 100 : i32
%end_a = hw.constant 200 : i32
%check_start_a = comb.icmp ugt %arg, %start_a : i32
%check_end_a = comb.icmp ult %arg, %end_a : i32
%in_a = comb.and %check_start_a, %check_end_a : i1
%elem = hw.constant 200 : i32
%eq_elem = comb.icmp eq %arg, %elem : i32
%in_range = comb.or %in_a, %eq_elem : i1
hw.output %in_range : i1
}
// CHECK-LABEL: @make_lower_bound
// CHECK_DAG: [[BOUND:%.*]] = hw.constant 6 : i32
// CHECK_DAG: [[RESULT:%.*]] = comb.icmp ult %arg, [[BOUND]] : i32
// CHECK_DAG: hw.output [[RESULT]] : i1
hw.module @make_lower_bound(%arg: i32) -> (cond: i1) {
%cst0 = hw.constant 0 : i32
%is0 = comb.icmp eq %cst0, %arg : i32
%cst1 = hw.constant 1 : i32
%is1 = comb.icmp eq %cst1, %arg : i32
%cst2 = hw.constant 2 : i32
%is2 = comb.icmp eq %arg, %cst2 : i32
%cst3 = hw.constant 3 : i32
%is3 = comb.icmp eq %arg, %cst3 : i32
%cst4 = hw.constant 4 : i32
%is4 = comb.icmp eq %arg, %cst4 : i32
%cst5 = hw.constant 5 : i32
%is5 = comb.icmp eq %arg, %cst5 : i32
%is0_1 = comb.or %is0, %is1 : i1
%is2_3 = comb.or %is2, %is3 : i1
%is4_5 = comb.or %is4, %is5 : i1
%is0_3 = comb.or %is0_1, %is2_3 : i1
%is0_5 = comb.or %is0_3, %is4_5 : i1
hw.output %is0_5 : i1
}
// CHECK-LABEL: @merge_with_lower_bound
// CHECK-DAG: [[END:%.*]] = hw.constant 30 : i32
// CHECK-DAG: [[CHECK_END:%.*]] = comb.icmp ult %arg, [[END]] : i32
// CHECK-DAG: hw.output [[CHECK_END]] : i1
hw.module @merge_with_lower_bound(%arg: i32) -> (cond: i1) {
%0 = hw.constant 6 : i32
%1 = comb.icmp ult %arg, %0 : i32
%2 = hw.constant 20 : i32
%3 = comb.icmp ult %arg, %2 : i32
%4 = hw.constant 10 : i32
%5 = comb.icmp eq %arg, %4 : i32
%6 = hw.constant 19 : i32
%7 = hw.constant 30 : i32
%8 = comb.icmp ugt %arg, %6 : i32
%9 = comb.icmp ult %arg, %7 : i32
%10 = comb.and %8, %9 : i1
%11 = comb.or %1, %3, %5, %10 : i1
hw.output %11 : i1
}
// CHECK-LABEL: @merge_with_upper_bound
// CHECK-DAG: [[CST:%.+]] = hw.constant -5 : i4
// CHECK-DAG: [[RESULT:%.+]] = comb.icmp ugt %arg, [[CST]] : i4
// CHECK-DAG: hw.output [[RESULT]] : i1
hw.module @merge_with_upper_bound(%arg: i4) -> (cond: i1) {
%cst15 = hw.constant 15 : i4
%is15 = comb.icmp eq %cst15, %arg : i4
%cst14 = hw.constant 14 : i4
%is14 = comb.icmp eq %cst14, %arg : i4
%cst13 = hw.constant 13 : i4
%is13 = comb.icmp eq %arg, %cst13 : i4
%cst12 = hw.constant 12 : i4
%is12 = comb.icmp eq %arg, %cst12 : i4
%in_range = comb.or %is12, %is13, %is14, %is15 : i1
hw.output %in_range : i1
}
// CHECK-LABEL: @merge_range_covers_all
// CHECK-NEXT: [[CST:%.+]] = hw.constant true
// CHECK-NEXT: hw.output [[CST]] : i1
hw.module @merge_range_covers_all(%0: i2) -> (a:i1) {
%c-1_i2 = hw.constant -1 : i2
%c0_i2 = hw.constant 0 : i2
%c-2_i2 = hw.constant -2 : i2
%c1_i2 = hw.constant 1 : i2
%1 = comb.icmp eq %0, %c-2_i2 : i2
%2 = comb.icmp eq %0, %c0_i2 : i2
%4 = comb.icmp eq %0, %c-1_i2 : i2
%5 = comb.icmp eq %0, %c1_i2 : i2
%6 = comb.or %1, %2, %4, %5 : i1
hw.output %6: i1
}
// CHECK-LABEL: @merge_range_covers_all_with_extra_arg
// CHECK-NEXT: [[CST:%.+]] = hw.constant true
// CHECK-NEXT: hw.output [[CST]] : i1
hw.module @merge_range_covers_all_with_extra_arg(%0: i2, %b: i1) -> (a:i1) {
%c-1_i2 = hw.constant -1 : i2
%c0_i2 = hw.constant 0 : i2
%c-2_i2 = hw.constant -2 : i2
%c1_i2 = hw.constant 1 : i2
%1 = comb.icmp eq %0, %c-2_i2 : i2
%2 = comb.icmp eq %0, %c0_i2 : i2
%4 = comb.icmp eq %0, %c-1_i2 : i2
%5 = comb.icmp eq %0, %c1_i2 : i2
%6 = comb.or %1, %2, %4, %5, %b : i1
hw.output %6: i1
}
// CHECK-LABEL: @remove_overlap
// CHECK-DAG: [[START:%.*]] = hw.constant 200 : i32
// CHECK-DAG: [[END:%.*]] = hw.constant 100 : i32
// CHECK-DAG: [[CHECK_START:%.*]] = comb.icmp ugt %arg, [[END]] : i32
// CHECK-DAG: [[CHECK_END:%.*]] = comb.icmp ult %arg, [[START]] : i32
// CHECK-DAG: [[RESULT:%.*]] = comb.and [[CHECK_START]], [[CHECK_END]] : i1
// CHECK-DAG: hw.output [[RESULT]] : i1
hw.module @remove_overlap(%arg: i32) -> (cond: i1) {
%start_a = hw.constant 100 : i32
%end_a = hw.constant 200 : i32
%check_start_a = comb.icmp ugt %arg, %start_a : i32
%check_end_a = comb.icmp ult %arg, %end_a : i32
%in_a = comb.and %check_start_a, %check_end_a : i1
%elem = hw.constant 150 : i32
%eq_elem = comb.icmp eq %arg, %elem : i32
%in_range = comb.or %in_a, %eq_elem : i1
hw.output %in_range : i1
}