[Arc] Add VectorizeOp canonicalization (#7146)

This commit is contained in:
elhewaty 2024-07-03 02:29:04 +03:00 committed by GitHub
parent f6ee408e22
commit dbb07f3aff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 262 additions and 1 deletions

View File

@ -17,6 +17,7 @@
#include "circt/Dialect/Seq/SeqOps.h" #include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Support/Namespace.h" #include "circt/Support/Namespace.h"
#include "circt/Support/SymCache.h" #include "circt/Support/SymCache.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@ -259,6 +260,12 @@ struct SinkArcInputsPattern : public SymOpRewritePattern<DefineOp> {
PatternRewriter &rewriter) const final; PatternRewriter &rewriter) const final;
}; };
struct MergeVectorizeOps : public OpRewritePattern<VectorizeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(VectorizeOp op,
PatternRewriter &rewriter) const final;
};
} // namespace } // namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -577,6 +584,79 @@ CompRegCanonicalizer::matchAndRewrite(seq::CompRegOp op,
return success(); return success();
} }
LogicalResult
MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp,
PatternRewriter &rewriter) const {
auto &currentBlock = vecOp.getBody().front();
IRMapping argMapping;
SmallVector<Value> newOperands;
SmallVector<VectorizeOp> vecOpsToRemove;
bool canBeMerged = false;
// Used to calculate the new positions of args after insertions and removals
unsigned paddedBy = 0;
for (unsigned argIdx = 0, numArgs = vecOp.getInputs().size();
argIdx < numArgs; ++argIdx) {
auto inputVec = vecOp.getInputs()[argIdx];
// Make sure that the input comes from a `VectorizeOp`
// Ensure that the input vector matches the output of the `otherVecOp`
// Make sure that the results of the otherVecOp have only one use
auto otherVecOp = inputVec[0].getDefiningOp<VectorizeOp>();
if (!otherVecOp || inputVec != otherVecOp.getResults() ||
otherVecOp == vecOp ||
!llvm::all_of(otherVecOp.getResults(),
[](auto result) { return result.hasOneUse(); })) {
newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
continue;
}
// If this flag is set that means we changed the IR so we cannot return
// failure
canBeMerged = true;
newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(),
otherVecOp.getOperands().end());
auto &otherBlock = otherVecOp.getBody().front();
for (auto &otherArg : otherBlock.getArguments()) {
auto newArg = currentBlock.insertArgument(
argIdx + paddedBy, otherArg.getType(), otherArg.getLoc());
argMapping.map(otherArg, newArg);
++paddedBy;
}
rewriter.setInsertionPointToStart(&currentBlock);
for (auto &op : otherBlock.without_terminator())
rewriter.clone(op, argMapping);
unsigned argNewPos = paddedBy + argIdx;
// Get the result of the return value and use it in all places the
// the `otherVecOp` results were used
auto retOp = cast<VectorizeReturnOp>(otherBlock.getTerminator());
rewriter.replaceAllUsesWith(currentBlock.getArgument(argNewPos),
argMapping.lookupOrDefault(retOp.getValue()));
currentBlock.eraseArgument(argNewPos);
vecOpsToRemove.push_back(otherVecOp);
// We erased an arg so the padding decreased by 1
paddedBy--;
}
// We didn't change the IR as there were no vectors to merge
if (!canBeMerged)
return failure();
// Set the new inputOperandSegments value
unsigned groupSize = vecOp.getResults().size();
unsigned numOfGroups = newOperands.size() / groupSize;
SmallVector<int32_t> newAttr(numOfGroups, groupSize);
vecOp.setInputOperandSegments(newAttr);
vecOp.getOperation()->setOperands(ValueRange(newOperands));
// Erase dead VectorizeOps
for (auto deadOp : vecOpsToRemove)
rewriter.eraseOp(deadOp);
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ArcCanonicalizerPass implementation // ArcCanonicalizerPass implementation
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -624,7 +704,8 @@ void ArcCanonicalizerPass::runOnOperation() {
dialect->getCanonicalizationPatterns(patterns); dialect->getCanonicalizationPatterns(patterns);
for (mlir::RegisteredOperationName op : ctxt.getRegisteredOperations()) for (mlir::RegisteredOperationName op : ctxt.getRegisteredOperations())
op.getCanonicalizationPatterns(patterns, &ctxt); op.getCanonicalizationPatterns(patterns, &ctxt);
patterns.add<ICMPCanonicalizer, CompRegCanonicalizer>(&getContext()); patterns.add<ICMPCanonicalizer, CompRegCanonicalizer, MergeVectorizeOps>(
&getContext());
// Don't test for convergence since it is often not reached. // Don't test for convergence since it is often not reached.
(void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),

View File

@ -295,3 +295,183 @@ hw.module @DontSinkDifferentConstants1(in %x: i4, out out0: i4, out out1: i4, ou
hw.output %0, %1, %2 : i4, i4, i4 hw.output %0, %1, %2 : i4, i4, i4
} }
// CHECK-NEXT: } // CHECK-NEXT: }
//===----------------------------------------------------------------------===//
// MergeVectorizeOps
//===----------------------------------------------------------------------===//
hw.module @VecOpTest(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8,
in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1,
in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) {
%R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.add %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
%L:4 = arc.vectorize(%R#0, %R#1, %R#2, %R#3), (%n, %p, %r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.and %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
%C:4 = arc.vectorize(%L#0, %L#1, %L#2, %L#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0 : i8, %arg1: i8):
%1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8
arc.vectorize.return %1692 : i8
}
%4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8
}
arc.define @FooMux(%arg0: i1, %arg1: i8, %arg2: i8) -> i8 {
%0 = comb.mux bin %arg0, %arg1, %arg2 : i8
arc.output %0 : i8
}
arc.define @Just_A_Dummy_Func(%arg0: i8, %arg1: i8) -> i8 {
%0 = comb.or %arg0, %arg1: i8
arc.output %0 : i8
}
// CHECK-LABEL: hw.module @VecOpTest(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) {
// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l), (%n, %p, %r, %t), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8):
// CHECK-NEXT: [[ADD0:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: [[AND:%.+]] = comb.and [[ADD0]], %arg2 : i8
// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg3) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }
hw.module @Test_2_in_1(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8,
in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1,
in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) {
%R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.add %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
%L:4 = arc.vectorize(%o, %v, %q, %s), (%n, %p, %r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.and %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
%C:4 = arc.vectorize(%R#0, %R#1, %R#2, %R#3), (%L#0, %L#1, %L#2, %L#3) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0 : i8, %arg1: i8):
%1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8
arc.vectorize.return %1692 : i8
}
%4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8
}
// CHECK-LABEL: hw.module @Test_2_in_1(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) {
// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l), (%o, %v, %q, %s), (%n, %p, %r, %t) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT ^bb0(%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8):
// CHECK-NEXT [[AND:%.+]] = comb.and %arg2, %arg3 : i8
// CHECK-NEXT [[ADD:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[ADD]], [[AND]]) : (i8, i8) -> i8
// CHECK-NEXT arc.vectorize.return [[CALL]] : i8
// CHECK-NEXT }
// CHECK-NEXT [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE:%.+]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT hw.output
// CHECK-NEXT }
hw.module @More_Than_One_Use(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8,
in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1,
in %clock: !seq.clock) {
%R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.add %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
%L:4 = arc.vectorize(%R#0, %R#1, %R#2, %R#3), (%n, %p, %r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.and %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
%C:4 = arc.vectorize(%L#0, %L#1, %L#2, %L#3), (%R#0, %R#1, %R#2, %R#3) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0 : i8, %arg1: i8):
%1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8
arc.vectorize.return %1692 : i8
}
%4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8
}
// CHECK-LABEL: hw.module @More_Than_One_Use(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock) {
// CHECK-NEXT: [[VEC0:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: arc.vectorize.return [[ADD]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC1:%.+]]:4 = arc.vectorize ([[VEC0]]#0, [[VEC0]]#1, [[VEC0]]#2, [[VEC0]]#3), (%n, %p, %r, %t), ([[VEC0]]#0, [[VEC0]]#1, [[VEC0]]#2, [[VEC0]]#3) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8):
// CHECK-NEXT: [[AND:%.+]] = comb.and %arg0, %arg1 : i8
// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg2) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC1]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }
arc.define @TLMonitor_14_arc(%arg0: i3) -> i3 {
arc.output %arg0 : i3
}
hw.module private @Self_Use(in %clock : !seq.clock) {
%0:2 = arc.vectorize (%clock, %clock), (%0#0, %0#1) : (!seq.clock, !seq.clock, i3, i3) -> (i3, i3) {
^bb0(%arg0: !seq.clock, %arg1: i3):
%1 = arc.state @TLMonitor_14_arc(%arg1) clock %arg0 latency 1 : (i3) -> i3
arc.vectorize.return %1 : i3
}
hw.output
}
// CHECK-LABEL: hw.module private @Self_Use(in %clock : !seq.clock) {
// CHECK-NEXT: [[VEC:%.+]]:2 = arc.vectorize (%clock, %clock), ([[VEC:%.+]]#0, [[VEC:%.+]]#1) : (!seq.clock, !seq.clock, i3, i3) -> (i3, i3) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: !seq.clock, %arg1: i3):
// CHECK-NEXT: [[RET:%.+]] = arc.state @TLMonitor_14_arc(%arg1) clock %arg0 latency 1 : (i3) -> i3
// CHECK-NEXT: arc.vectorize.return [[RET]] : i3
// CHECK-NEXT: }
// CHECK-NEXT: hw.output
// CHECK-NEXT: }
hw.module @Needs_Shuffle(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8,
in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1,
in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) {
%R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.add %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
%L:4 = arc.vectorize(%R#1, %R#0, %R#2, %R#3), (%n, %p, %r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.and %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
%C:4 = arc.vectorize(%L#1, %L#0, %L#2, %L#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0 : i8, %arg1: i8):
%1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8
arc.vectorize.return %1692 : i8
}
%4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8
}
// CHECK-LABEL: hw.module @Needs_Shuffle(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) {
// CHECK-NEXT: [[VEC0:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[OUT:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC1:%.+]]:4 = arc.vectorize ([[VEC0]]#1, [[VEC0]]#0, [[VEC0]]#2, [[VEC0]]#3), (%n, %p, %r, %t) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[OUT:%.+]] = comb.and %arg0, %arg1 : i8
// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC2:%.+]]:4 = arc.vectorize ([[VEC1]]#1, [[VEC1]]#0, [[VEC1]]#2, [[VEC1]]#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[OUT:%.+]] = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC2]]#0, [[STATE:%.+]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }