From dbb07f3aff01c2dab6fa5eccf60798a8103d06e9 Mon Sep 17 00:00:00 2001 From: elhewaty Date: Wed, 3 Jul 2024 02:29:04 +0300 Subject: [PATCH] [Arc] Add VectorizeOp canonicalization (#7146) --- .../Arc/Transforms/ArcCanonicalizer.cpp | 83 +++++++- test/Dialect/Arc/arc-canonicalizer.mlir | 180 ++++++++++++++++++ 2 files changed, 262 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp b/lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp index 58f682ed21..dc05edf027 100644 --- a/lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp +++ b/lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp @@ -17,6 +17,7 @@ #include "circt/Dialect/Seq/SeqOps.h" #include "circt/Support/Namespace.h" #include "circt/Support/SymCache.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -259,6 +260,12 @@ struct SinkArcInputsPattern : public SymOpRewritePattern { PatternRewriter &rewriter) const final; }; +struct MergeVectorizeOps : public OpRewritePattern<VectorizeOp> { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(VectorizeOp op, + PatternRewriter &rewriter) const final; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -577,6 +584,79 @@ CompRegCanonicalizer::matchAndRewrite(seq::CompRegOp op, return success(); } +LogicalResult +MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp, + PatternRewriter &rewriter) const { + auto &currentBlock = vecOp.getBody().front(); + IRMapping argMapping; + SmallVector<Value> newOperands; + SmallVector<VectorizeOp> vecOpsToRemove; + bool canBeMerged = false; + // Used to calculate the new positions of args after insertions and removals + unsigned paddedBy = 0; + + for (unsigned argIdx = 0, numArgs = vecOp.getInputs().size(); + argIdx < numArgs; ++argIdx) { + auto inputVec = vecOp.getInputs()[argIdx]; + // Make sure that the input comes from a `VectorizeOp` + // Ensure that the input vector matches the output of 
the `otherVecOp` + // Make sure that the results of the otherVecOp have only one use + auto otherVecOp = inputVec[0].getDefiningOp<VectorizeOp>(); + if (!otherVecOp || inputVec != otherVecOp.getResults() || + otherVecOp == vecOp || + !llvm::all_of(otherVecOp.getResults(), + [](auto result) { return result.hasOneUse(); })) { + newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end()); + continue; + } + // If this flag is set that means we changed the IR so we cannot return + // failure + canBeMerged = true; + newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(), + otherVecOp.getOperands().end()); + + auto &otherBlock = otherVecOp.getBody().front(); + for (auto &otherArg : otherBlock.getArguments()) { + auto newArg = currentBlock.insertArgument( + argIdx + paddedBy, otherArg.getType(), otherArg.getLoc()); + argMapping.map(otherArg, newArg); + ++paddedBy; + } + + rewriter.setInsertionPointToStart(&currentBlock); + for (auto &op : otherBlock.without_terminator()) + rewriter.clone(op, argMapping); + + unsigned argNewPos = paddedBy + argIdx; + // Get the result of the return value and use it in all places + // the `otherVecOp` results were used + auto retOp = cast<VectorizeReturnOp>(otherBlock.getTerminator()); + rewriter.replaceAllUsesWith(currentBlock.getArgument(argNewPos), + argMapping.lookupOrDefault(retOp.getValue())); + currentBlock.eraseArgument(argNewPos); + vecOpsToRemove.push_back(otherVecOp); + // We erased an arg so the padding decreased by 1 + paddedBy--; + } + + // We didn't change the IR as there were no vectors to merge + if (!canBeMerged) + return failure(); + + // Set the new inputOperandSegments value + unsigned groupSize = vecOp.getResults().size(); + unsigned numOfGroups = newOperands.size() / groupSize; + SmallVector<int32_t> newAttr(numOfGroups, groupSize); + vecOp.setInputOperandSegments(newAttr); + vecOp.getOperation()->setOperands(ValueRange(newOperands)); + + // Erase dead VectorizeOps + for (auto deadOp : vecOpsToRemove) + rewriter.eraseOp(deadOp); + 
return success(); +} + //===----------------------------------------------------------------------===// // ArcCanonicalizerPass implementation //===----------------------------------------------------------------------===// @@ -624,7 +704,8 @@ void ArcCanonicalizerPass::runOnOperation() { dialect->getCanonicalizationPatterns(patterns); for (mlir::RegisteredOperationName op : ctxt.getRegisteredOperations()) op.getCanonicalizationPatterns(patterns, &ctxt); - patterns.add(&getContext()); + patterns.add( + &getContext()); // Don't test for convergence since it is often not reached. (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), diff --git a/test/Dialect/Arc/arc-canonicalizer.mlir b/test/Dialect/Arc/arc-canonicalizer.mlir index ba761049c8..88ac327d4f 100644 --- a/test/Dialect/Arc/arc-canonicalizer.mlir +++ b/test/Dialect/Arc/arc-canonicalizer.mlir @@ -295,3 +295,183 @@ hw.module @DontSinkDifferentConstants1(in %x: i4, out out0: i4, out out1: i4, ou hw.output %0, %1, %2 : i4, i4, i4 } // CHECK-NEXT: } + +//===----------------------------------------------------------------------===// +// MergeVectorizeOps +//===----------------------------------------------------------------------===// + +hw.module @VecOpTest(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8, +in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1, +in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) { + %R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0: i8, %arg1: i8): + %ret = comb.add %arg0, %arg1: i8 + arc.vectorize.return %ret: i8 + } + %L:4 = arc.vectorize(%R#0, %R#1, %R#2, %R#3), (%n, %p, %r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0: i8, %arg1: i8): + %ret = comb.and %arg0, %arg1: i8 + arc.vectorize.return %ret: i8 + } + %C:4 = arc.vectorize(%L#0, %L#1, %L#2, %L#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, 
i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0 : i8, %arg1: i8): + %1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8 + arc.vectorize.return %1692 : i8 + } + %4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8 +} + +arc.define @FooMux(%arg0: i1, %arg1: i8, %arg2: i8) -> i8 { + %0 = comb.mux bin %arg0, %arg1, %arg2 : i8 + arc.output %0 : i8 +} +arc.define @Just_A_Dummy_Func(%arg0: i8, %arg1: i8) -> i8 { + %0 = comb.or %arg0, %arg1: i8 + arc.output %0 : i8 +} + +// CHECK-LABEL: hw.module @VecOpTest(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) { +// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l), (%n, %p, %r, %t), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { +// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8): +// CHECK-NEXT: [[ADD0:%.+]] = comb.add %arg0, %arg1 : i8 +// CHECK-NEXT: [[AND:%.+]] = comb.and [[ADD0]], %arg2 : i8 +// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg3) : (i8, i8) -> i8 +// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8 +// CHECK-NEXT: } +// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8 +// CHECK-NEXT: hw.output +// CHECK-NEXT: } + +hw.module @Test_2_in_1(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8, +in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1, +in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) { + %R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0: i8, %arg1: i8): + %ret = comb.add %arg0, %arg1: i8 + arc.vectorize.return %ret: i8 + } + %L:4 = arc.vectorize(%o, %v, 
%q, %s), (%n, %p, %r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0: i8, %arg1: i8): + %ret = comb.and %arg0, %arg1: i8 + arc.vectorize.return %ret: i8 + } + %C:4 = arc.vectorize(%R#0, %R#1, %R#2, %R#3), (%L#0, %L#1, %L#2, %L#3) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0 : i8, %arg1: i8): + %1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8 + arc.vectorize.return %1692 : i8 + } + %4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8 +} + +// CHECK-LABEL: hw.module @Test_2_in_1(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) { +// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l), (%o, %v, %q, %s), (%n, %p, %r, %t) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { +// CHECK-NEXT: ^bb0(%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8): +// CHECK-NEXT: [[AND:%.+]] = comb.and %arg2, %arg3 : i8 +// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8 +// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[ADD]], [[AND]]) : (i8, i8) -> i8 +// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8 +// CHECK-NEXT: } +// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE:%.+]]) clock %clock latency 1 : (i1, i8, i8) -> i8 +// CHECK-NEXT: hw.output +// CHECK-NEXT: } + + +hw.module @More_Than_One_Use(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8, +in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1, +in %clock: !seq.clock) { + %R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0: i8, %arg1: i8): + %ret = comb.add %arg0, %arg1: i8 + arc.vectorize.return %ret: i8 + } + %L:4 = arc.vectorize(%R#0, %R#1, %R#2, %R#3), (%n, %p, 
%r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0: i8, %arg1: i8): + %ret = comb.and %arg0, %arg1: i8 + arc.vectorize.return %ret: i8 + } + %C:4 = arc.vectorize(%L#0, %L#1, %L#2, %L#3), (%R#0, %R#1, %R#2, %R#3) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0 : i8, %arg1: i8): + %1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8 + arc.vectorize.return %1692 : i8 + } + %4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8 +} + +// CHECK-LABEL: hw.module @More_Than_One_Use(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock) { +// CHECK-NEXT: [[VEC0:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { +// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8): +// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8 +// CHECK-NEXT: arc.vectorize.return [[ADD]] : i8 +// CHECK-NEXT: } +// CHECK-NEXT: [[VEC1:%.+]]:4 = arc.vectorize ([[VEC0]]#0, [[VEC0]]#1, [[VEC0]]#2, [[VEC0]]#3), (%n, %p, %r, %t), ([[VEC0]]#0, [[VEC0]]#1, [[VEC0]]#2, [[VEC0]]#3) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { +// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8): +// CHECK-NEXT: [[AND:%.+]] = comb.and %arg0, %arg1 : i8 +// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg2) : (i8, i8) -> i8 +// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8 +// CHECK-NEXT: } +// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC1]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8 +// CHECK-NEXT: hw.output +// CHECK-NEXT: } + +arc.define @TLMonitor_14_arc(%arg0: i3) -> i3 { + arc.output %arg0 : i3 +} + +hw.module private @Self_Use(in %clock : !seq.clock) { + %0:2 = arc.vectorize (%clock, %clock), (%0#0, %0#1) : (!seq.clock, !seq.clock, i3, i3) -> 
(i3, i3) { + ^bb0(%arg0: !seq.clock, %arg1: i3): + %1 = arc.state @TLMonitor_14_arc(%arg1) clock %arg0 latency 1 : (i3) -> i3 + arc.vectorize.return %1 : i3 + } + hw.output +} + +// CHECK-LABEL: hw.module private @Self_Use(in %clock : !seq.clock) { +// CHECK-NEXT: [[VEC:%.+]]:2 = arc.vectorize (%clock, %clock), ([[VEC:%.+]]#0, [[VEC:%.+]]#1) : (!seq.clock, !seq.clock, i3, i3) -> (i3, i3) { +// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: !seq.clock, %arg1: i3): +// CHECK-NEXT: [[RET:%.+]] = arc.state @TLMonitor_14_arc(%arg1) clock %arg0 latency 1 : (i3) -> i3 +// CHECK-NEXT: arc.vectorize.return [[RET]] : i3 +// CHECK-NEXT: } +// CHECK-NEXT: hw.output +// CHECK-NEXT: } + +hw.module @Needs_Shuffle(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8, +in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1, +in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) { + %R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0: i8, %arg1: i8): + %ret = comb.add %arg0, %arg1: i8 + arc.vectorize.return %ret: i8 + } + %L:4 = arc.vectorize(%R#1, %R#0, %R#2, %R#3), (%n, %p, %r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0: i8, %arg1: i8): + %ret = comb.and %arg0, %arg1: i8 + arc.vectorize.return %ret: i8 + } + %C:4 = arc.vectorize(%L#1, %L#0, %L#2, %L#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0 : i8, %arg1: i8): + %1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8 + arc.vectorize.return %1692 : i8 + } + %4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8 +} + +// CHECK-LABEL: hw.module @Needs_Shuffle(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) { +// CHECK-NEXT: 
[[VEC0:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { +// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8): +// CHECK-NEXT: [[OUT:%.+]] = comb.add %arg0, %arg1 : i8 +// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8 +// CHECK-NEXT: } +// CHECK-NEXT: [[VEC1:%.+]]:4 = arc.vectorize ([[VEC0]]#1, [[VEC0]]#0, [[VEC0]]#2, [[VEC0]]#3), (%n, %p, %r, %t) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { +// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8): +// CHECK-NEXT: [[OUT:%.+]] = comb.and %arg0, %arg1 : i8 +// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8 +// CHECK-NEXT: } +// CHECK-NEXT: [[VEC2:%.+]]:4 = arc.vectorize ([[VEC1]]#1, [[VEC1]]#0, [[VEC1]]#2, [[VEC1]]#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { +// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8): +// CHECK-NEXT: [[OUT:%.+]] = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8 +// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8 +// CHECK-NEXT: } +// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC2]]#0, [[STATE:%.+]]) clock %clock latency 1 : (i1, i8, i8) -> i8 +// CHECK-NEXT: hw.output +// CHECK-NEXT: }