[MLIR] Generalize select to arithmetic canonicalization
Given a select whose result is an i1, we can eliminate the conditional in the select completely by adding a few arithmetic operations. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D116839
This commit is contained in:
parent
720c48b58e
commit
a02af37560
|
@ -813,29 +813,31 @@ static LogicalResult verify(ReturnOp op) {
|
|||
// SelectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Transforms a select to a not, where relevant.
|
||||
// Transforms a select of a boolean to arithmetic operations
|
||||
//
|
||||
// select %arg, %false, %true
|
||||
// select %arg, %x, %y : i1
|
||||
//
|
||||
// becomes
|
||||
//
|
||||
// xor %arg, %true
|
||||
struct SelectToNot : public OpRewritePattern<SelectOp> {
|
||||
// and(%arg, %x) or and(!%arg, %y)
|
||||
struct SelectI1Simplify : public OpRewritePattern<SelectOp> {
|
||||
using OpRewritePattern<SelectOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SelectOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!matchPattern(op.getTrueValue(), m_Zero()))
|
||||
return failure();
|
||||
|
||||
if (!matchPattern(op.getFalseValue(), m_One()))
|
||||
return failure();
|
||||
|
||||
if (!op.getType().isInteger(1))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<arith::XOrIOp>(op, op.getCondition(),
|
||||
op.getFalseValue());
|
||||
Value falseConstant =
|
||||
rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
|
||||
Value notCondition = rewriter.create<arith::XOrIOp>(
|
||||
op.getLoc(), op.getCondition(), falseConstant);
|
||||
|
||||
Value trueVal = rewriter.create<arith::AndIOp>(
|
||||
op.getLoc(), op.getCondition(), op.getTrueValue());
|
||||
Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
|
||||
op.getFalseValue());
|
||||
rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -876,7 +878,7 @@ struct SelectToExtUI : public OpRewritePattern<SelectOp> {
|
|||
|
||||
void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<SelectToNot, SelectToExtUI>(context);
|
||||
results.insert<SelectI1Simplify, SelectToExtUI>(context);
|
||||
}
|
||||
|
||||
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
|
|
@ -88,10 +88,23 @@ func @branchCondProp(%arg0: i1) {
|
|||
|
||||
// CHECK-LABEL: @selToNot
|
||||
// CHECK: %[[trueval:.+]] = arith.constant true
|
||||
// CHECK: %{{.+}} = arith.xori %arg0, %[[trueval]] : i1
|
||||
// CHECK: %[[res:.+]] = arith.xori %arg0, %[[trueval]] : i1
|
||||
// CHECK: return %[[res]]
|
||||
func @selToNot(%arg0: i1) -> i1 {
|
||||
%true = arith.constant true
|
||||
%false = arith.constant false
|
||||
%res = select %arg0, %false, %true : i1
|
||||
return %res : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @selToArith
|
||||
// CHECK-NEXT: %[[trueval:.+]] = arith.constant true
|
||||
// CHECK-NEXT: %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1
|
||||
// CHECK-NEXT: %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1
|
||||
// CHECK-NEXT: %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1
|
||||
// CHECK-NEXT: %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
|
||||
// CHECK: return %[[res]]
|
||||
func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
|
||||
%res = select %arg0, %arg1, %arg2 : i1
|
||||
return %res : i1
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue