diff --git a/include/circt/Dialect/Calyx/CalyxControl.td b/include/circt/Dialect/Calyx/CalyxControl.td index e28d6e371d..43414c375f 100644 --- a/include/circt/Dialect/Calyx/CalyxControl.td +++ b/include/circt/Dialect/Calyx/CalyxControl.td @@ -68,7 +68,7 @@ def IfOp : CalyxContainer<"if", [ let assemblyFormat = "$cond (`with` $groupName^)? $thenRegion (`else` $elseRegion^)? attr-dict"; let verifier = "return ::verify$cppClass(*this);"; - let hasCanonicalizeMethod = true; + let hasCanonicalizer = true; let extraClassDeclaration = [{ /// Checks whether the `then` body exists. bool thenBodyExists() { diff --git a/lib/Dialect/Calyx/CalyxOps.cpp b/lib/Dialect/Calyx/CalyxOps.cpp index cb514ff623..74676c60bf 100644 --- a/lib/Dialect/Calyx/CalyxOps.cpp +++ b/lib/Dialect/Calyx/CalyxOps.cpp @@ -1416,56 +1416,141 @@ static EnableOp getLastEnableOp(SeqOp parent) { return nullptr; } -/// Removes common tail enable operations for sequential 'then'/'else' -/// branches inside an 'if' operation. -/// -/// if %a with %A { if %a with %A { -/// seq { ... calyx.enable @B } seq { ... } -/// else { -> } else { -/// seq { ... calyx.enable @B } seq { ... } -/// } } -/// calyx.enable @B -static LogicalResult eliminateCommonTailEnable(IfOp ifOp, - PatternRewriter &rewriter) { - // Check if the branches exist. - if (!ifOp.thenBodyExists() || !ifOp.elseBodyExists()) - return failure(); +/// Returns a mapping of {enabled Group name, EnableOp} for all EnableOps within +/// the immediate ParOp's body. +static llvm::StringMap getAllEnableOpsInImmediateBody(ParOp parent) { + llvm::StringMap enables; + Block *body = parent.getBody(); + for (EnableOp op : body->getOps()) + enables.insert(std::pair(op.groupName(), op)); - auto &thenOpStructureOp = ifOp.getThenBody()->front(); - auto &elseOpStructureOp = ifOp.getElseBody()->front(); - // TODO(circt/#1861): ParOps have less restrictive conditions. - if (isa(thenOpStructureOp) || isa(elseOpStructureOp)) - return failure(); - - // At this point, only sequential operations are valid inside the branches. - auto thenSeqOp = dyn_cast(thenOpStructureOp); - auto elseSeqOp = dyn_cast(elseOpStructureOp); - assert(thenSeqOp && elseSeqOp && - "expected nested seq ops in both branches of a calyx.IfOp"); - - EnableOp lastThenEnableOp = getLastEnableOp(thenSeqOp); - EnableOp lastElseEnableOp = getLastEnableOp(elseSeqOp); - - if (lastThenEnableOp == nullptr || lastElseEnableOp == nullptr) - return failure(); - - if (lastThenEnableOp.groupName() != lastElseEnableOp.groupName()) - return failure(); - - // Erase both enable operations and add group enable operation after the - // shared IfOp parent. - rewriter.setInsertionPointAfter(ifOp); - rewriter.create(ifOp.getLoc(), lastThenEnableOp.groupName()); - rewriter.eraseOp(lastThenEnableOp); - rewriter.eraseOp(lastElseEnableOp); - return success(); + return enables; } -LogicalResult IfOp::canonicalize(IfOp ifOp, PatternRewriter &rewriter) { - if (succeeded(eliminateCommonTailEnable(ifOp, rewriter))) - return success(); +/// Checks preconditions for the common tail pattern. This canonicalization is +/// stringent about not entering nested control operations, as this may cause +/// unintentional changes in behavior. +/// We only look for two cases: (1) both regions are ParOps, and +/// (2) both regions are SeqOps. The case when these are different, e.g. ParOp +/// and SeqOp, will only produce less optimal code, or even worse, change the +/// behavior. +template +static bool hasCommonTailPatternPreConditions(IfOp op) { + static_assert(std::is_same() || std::is_same(), + "Should be a SeqOp or ParOp."); - return failure(); + if (!op.thenBodyExists() || !op.elseBodyExists()) + return false; + + Block *thenBody = op.getThenBody(), *elseBody = op.getElseBody(); + return isa(thenBody->front()) && isa(elseBody->front()); +} + +/// seq { +/// if %a with @G { if %a with @G { +/// seq { ... calyx.enable @A } seq { ... } +/// else { -> } else { +/// seq { ... calyx.enable @A } seq { ... } +/// } } +/// calyx.enable @A +/// } +struct CommonTailPatternWithSeq : mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IfOp ifOp, + PatternRewriter &rewriter) const override { + if (!hasCommonTailPatternPreConditions(ifOp)) + return failure(); + + auto thenControl = cast(ifOp.getThenBody()->front()), + elseControl = cast(ifOp.getElseBody()->front()); + EnableOp lastThenEnableOp = getLastEnableOp(thenControl), + lastElseEnableOp = getLastEnableOp(elseControl); + + if (lastThenEnableOp == nullptr || lastElseEnableOp == nullptr) + return failure(); + if (lastThenEnableOp.groupName() != lastElseEnableOp.groupName()) + return failure(); + + // Place the IfOp and pulled EnableOp inside a sequential region, in case + // this IfOp is nested in a ParOp. This avoids unintentionally + // parallelizing the pulled out EnableOps. + rewriter.setInsertionPointAfter(ifOp); + SeqOp seqOp = rewriter.create(ifOp.getLoc()); + rewriter.createBlock(&seqOp.getBodyRegion()); + Block *body = seqOp.getBody(); + ifOp->remove(); + body->push_back(ifOp); + rewriter.create(seqOp.getLoc(), lastThenEnableOp.groupName()); + + // Erase the common EnableOp from the Then and Else regions. + rewriter.eraseOp(lastThenEnableOp); + rewriter.eraseOp(lastElseEnableOp); + return success(); + } +}; + +/// if %a with @G { par { +/// par { if %a with @G { +/// ... par { ... } +/// calyx.enable @A } else { +/// calyx.enable @B -> par { ... } +/// } } +/// } else { calyx.enable @A +/// par { calyx.enable @B +/// ... } +/// calyx.enable @A +/// calyx.enable @B +/// } +/// } +struct CommonTailPatternWithPar : mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IfOp ifOp, + PatternRewriter &rewriter) const override { + if (!hasCommonTailPatternPreConditions(ifOp)) + return failure(); + auto thenControl = cast(ifOp.getThenBody()->front()), + elseControl = cast(ifOp.getElseBody()->front()); + + llvm::StringMap A = getAllEnableOpsInImmediateBody(thenControl), + B = getAllEnableOpsInImmediateBody(elseControl); + + // Compute the intersection between `A` and `B`. + SmallVector groupNames; + for (auto a = A.begin(); a != A.end(); ++a) { + StringRef groupName = a->getKey(); + auto b = B.find(groupName); + if (b == B.end()) + continue; + // This is also an element in B. + groupNames.push_back(groupName); + // Since these are being pulled out, erase them. + rewriter.eraseOp(a->getValue()); + rewriter.eraseOp(b->getValue()); + } + // Place the IfOp and EnableOp(s) inside a parallel region, in case this + // IfOp is nested in a SeqOp. This avoids unintentionally sequentializing + // the pulled out EnableOps. + rewriter.setInsertionPointAfter(ifOp); + ParOp parOp = rewriter.create(ifOp.getLoc()); + rewriter.createBlock(&parOp.getBodyRegion()); + Block *body = parOp.getBody(); + ifOp->remove(); + body->push_back(ifOp); + + // Pull out the intersection between these two sets, and erase their + // counterparts in the Then and Else regions. + for (StringRef groupName : groupNames) + rewriter.create(parOp.getLoc(), groupName); + + return success(); + } +}; + +void IfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); } //===----------------------------------------------------------------------===// diff --git a/test/Dialect/Calyx/canonicalization.mlir b/test/Dialect/Calyx/canonicalization.mlir index d5f5b71fd6..bc70b867df 100644 --- a/test/Dialect/Calyx/canonicalization.mlir +++ b/test/Dialect/Calyx/canonicalization.mlir @@ -54,7 +54,7 @@ calyx.program "main" { } } -// IfOp removes common tails from within SeqOps. +// IfOp nested in SeqOp removes common tail from within SeqOps. calyx.program "main" { calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%done: i1 {done}) { %r.in, %r.write_en, %r.clk, %r.reset, %r.out, %r.done = calyx.register "r" : i1, i1, i1, i1, i1, i1 @@ -87,7 +87,7 @@ calyx.program "main" { // CHECK-NEXT: calyx.seq { // CHECK-NEXT: calyx.enable @B // CHECK-NEXT: } - // CHECK-NEXT: else { + // CHECK-NEXT: } else { // CHECK-NEXT: calyx.seq { // CHECK-NEXT: calyx.enable @C // CHECK-NEXT: } @@ -112,3 +112,202 @@ calyx.program "main" { } } } + +// IfOp nested in ParOp removes common tails from within ParOps. +calyx.program "main" { + calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%done: i1 {done}) { + %r.in, %r.write_en, %r.clk, %r.reset, %r.out, %r.done = calyx.register "r" : i1, i1, i1, i1, i1, i1 + %eq.left, %eq.right, %eq.out = calyx.std_eq "eq" : i1, i1, i1 + %c1_1 = hw.constant 1 : i1 + calyx.wires { + calyx.comb_group @Cond { + calyx.assign %eq.left = %c1_1 : i1 + calyx.assign %eq.right = %c1_1 : i1 + } + calyx.group @A { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + calyx.group @B { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + calyx.group @C { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + calyx.group @D { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + } + // CHECK-LABEL: calyx.control { + // CHECK-NEXT: calyx.par { + // CHECK-NEXT: calyx.if %eq.out with @Cond { + // CHECK-NEXT: calyx.par { + // CHECK-NEXT: calyx.enable @A + // CHECK-NEXT: } + // CHECK-NEXT: } else { + // CHECK-NEXT: calyx.par { + // CHECK-NEXT: calyx.enable @B + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: calyx.enable @C + // CHECK-NEXT: calyx.enable @D + // CHECK-NEXT: } + // CHECK-NEXT: } + calyx.control { + calyx.par { + calyx.if %eq.out with @Cond { + calyx.par { + calyx.enable @A + calyx.enable @C + calyx.enable @D + } + } else { + calyx.par { + calyx.enable @B + calyx.enable @C + calyx.enable @D + } + } + } + } + } +} + +// IfOp nested in ParOp removes common tail from within SeqOps. The important check +// here is ensuring the removed EnableOps are still computed sequentially. +calyx.program "main" { + calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%done: i1 {done}) { + %r.in, %r.write_en, %r.clk, %r.reset, %r.out, %r.done = calyx.register "r" : i1, i1, i1, i1, i1, i1 + %eq.left, %eq.right, %eq.out = calyx.std_eq "eq" : i1, i1, i1 + %c1_1 = hw.constant 1 : i1 + calyx.wires { + calyx.comb_group @Cond { + calyx.assign %eq.left = %c1_1 : i1 + calyx.assign %eq.right = %c1_1 : i1 + } + calyx.group @A { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + calyx.group @B { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + calyx.group @C { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + } + // CHECK-LABEL: calyx.control { + // CHECK-NEXT: calyx.par { + // CHECK-NEXT: calyx.seq { + // CHECK-NEXT: calyx.if %eq.out with @Cond { + // CHECK-NEXT: calyx.seq { + // CHECK-NEXT: calyx.enable @B + // CHECK-NEXT: } + // CHECK-NEXT: } else { + // CHECK-NEXT: calyx.seq { + // CHECK-NEXT: calyx.enable @C + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: calyx.enable @A + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + calyx.control { + calyx.par { + calyx.if %eq.out with @Cond { + calyx.seq { + calyx.enable @B + calyx.enable @A + } + } else { + calyx.seq { + calyx.enable @C + calyx.enable @A + } + } + } + } + } +} + +// IfOp nested in SeqOp removes common tail from within ParOps. The important check +// here is ensuring the removed EnableOps are still computed in parallel. +calyx.program "main" { + calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%done: i1 {done}) { + %r.in, %r.write_en, %r.clk, %r.reset, %r.out, %r.done = calyx.register "r" : i1, i1, i1, i1, i1, i1 + %eq.left, %eq.right, %eq.out = calyx.std_eq "eq" : i1, i1, i1 + %c1_1 = hw.constant 1 : i1 + calyx.wires { + calyx.comb_group @Cond { + calyx.assign %eq.left = %c1_1 : i1 + calyx.assign %eq.right = %c1_1 : i1 + } + calyx.group @A { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + calyx.group @B { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + calyx.group @C { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + calyx.group @D { + calyx.assign %r.in = %c1_1 : i1 + calyx.assign %r.write_en = %c1_1 : i1 + calyx.group_done %r.done : i1 + } + } + // CHECK-LABEL: calyx.control { + // CHECK-NEXT: calyx.seq { + // CHECK-NEXT: calyx.par { + // CHECK-NEXT: calyx.if %eq.out with @Cond { + // CHECK-NEXT: calyx.par { + // CHECK-NEXT: calyx.enable @A + // CHECK-NEXT: } + // CHECK-NEXT: } else { + // CHECK-NEXT: calyx.par { + // CHECK-NEXT: calyx.enable @B + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: calyx.enable @C + // CHECK-NEXT: calyx.enable @D + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + calyx.control { + calyx.seq { + calyx.if %eq.out with @Cond { + calyx.par { + calyx.enable @A + calyx.enable @C + calyx.enable @D + } + } else { + calyx.par { + calyx.enable @B + calyx.enable @C + calyx.enable @D + } + } + } + } + } +}