[Calyx] Add enhancements to common tail elimination for `IfOp`. (#1895)

This adds a pattern for the case where the then and else regions of an IfOp are both ParOps. It also
takes the parent region of the IfOp into consideration, since we don't want to unintentionally change
the behavior when pulling out EnableOps. Anything that is not immediately within the given ParOp's body
is left untouched; modifying it may also change the behavior of the program. When the then and else
regions are different, e.g. SeqOp and ParOp, an EnableOp should never be pulled out; doing so will always produce worse code.
This commit is contained in:
Chris Gyurgyik 2021-09-29 20:37:27 -07:00 committed by GitHub
parent 5f7c2f8c96
commit 729d9280ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 333 additions and 49 deletions

View File

@ -68,7 +68,7 @@ def IfOp : CalyxContainer<"if", [
let assemblyFormat = "$cond (`with` $groupName^)? $thenRegion (`else` $elseRegion^)? attr-dict";
let verifier = "return ::verify$cppClass(*this);";
let hasCanonicalizeMethod = true;
let hasCanonicalizer = true;
let extraClassDeclaration = [{
/// Checks whether the `then` body exists.
bool thenBodyExists() {

View File

@ -1416,56 +1416,141 @@ static EnableOp getLastEnableOp(SeqOp parent) {
return nullptr;
}
/// Removes common tail enable operations for sequential 'then'/'else'
/// branches inside an 'if' operation.
///
/// if %a with %A { if %a with %A {
/// seq { ... calyx.enable @B } seq { ... }
/// else { -> } else {
/// seq { ... calyx.enable @B } seq { ... }
/// } }
/// calyx.enable @B
static LogicalResult eliminateCommonTailEnable(IfOp ifOp,
PatternRewriter &rewriter) {
// Check if the branches exist.
if (!ifOp.thenBodyExists() || !ifOp.elseBodyExists())
return failure();
/// Returns a mapping of {enabled Group name, EnableOp} for all EnableOps within
/// the immediate ParOp's body.
static llvm::StringMap<EnableOp> getAllEnableOpsInImmediateBody(ParOp parent) {
llvm::StringMap<EnableOp> enables;
Block *body = parent.getBody();
for (EnableOp op : body->getOps<EnableOp>())
enables.insert(std::pair(op.groupName(), op));
auto &thenOpStructureOp = ifOp.getThenBody()->front();
auto &elseOpStructureOp = ifOp.getElseBody()->front();
// TODO(circt/#1861): ParOps have less restrictive conditions.
if (isa<ParOp>(thenOpStructureOp) || isa<ParOp>(elseOpStructureOp))
return failure();
// At this point, only sequential operations are valid inside the branches.
auto thenSeqOp = dyn_cast<SeqOp>(thenOpStructureOp);
auto elseSeqOp = dyn_cast<SeqOp>(elseOpStructureOp);
assert(thenSeqOp && elseSeqOp &&
"expected nested seq ops in both branches of a calyx.IfOp");
EnableOp lastThenEnableOp = getLastEnableOp(thenSeqOp);
EnableOp lastElseEnableOp = getLastEnableOp(elseSeqOp);
if (lastThenEnableOp == nullptr || lastElseEnableOp == nullptr)
return failure();
if (lastThenEnableOp.groupName() != lastElseEnableOp.groupName())
return failure();
// Erase both enable operations and add group enable operation after the
// shared IfOp parent.
rewriter.setInsertionPointAfter(ifOp);
rewriter.create<EnableOp>(ifOp.getLoc(), lastThenEnableOp.groupName());
rewriter.eraseOp(lastThenEnableOp);
rewriter.eraseOp(lastElseEnableOp);
return success();
return enables;
}
LogicalResult IfOp::canonicalize(IfOp ifOp, PatternRewriter &rewriter) {
if (succeeded(eliminateCommonTailEnable(ifOp, rewriter)))
return success();
/// Checks preconditions for the common tail pattern. This canonicalization is
/// stringent about not entering nested control operations, as this may cause
/// unintentional changes in behavior.
/// We only look for two cases: (1) both regions are ParOps, and
/// (2) both regions are SeqOps. The case when these are different, e.g. ParOp
/// and SeqOp, will only produce less optimal code, or even worse, change the
/// behavior.
///
/// `OpTy` is the control operation expected at the front of both branches; it
/// is restricted at compile time to SeqOp or ParOp. The trailing checks accept
/// `op` only when it has both a `then` and an `else` body and the first
/// operation in each body is an `OpTy`.
template <typename OpTy>
static bool hasCommonTailPatternPreConditions(IfOp op) {
static_assert(std::is_same<SeqOp, OpTy>() || std::is_same<ParOp, OpTy>(),
"Should be a SeqOp or ParOp.");
// NOTE(review): this `return failure();` sits in a function declared to
// return `bool` and makes every statement below it unreachable. It looks like
// a removed line left over from the diff this text was taken from -- confirm
// against the real source and drop it.
return failure();
// Both branches must exist before their first operations can be inspected.
if (!op.thenBodyExists() || !op.elseBodyExists())
return false;
Block *thenBody = op.getThenBody(), *elseBody = op.getElseBody();
// Only the first operation of each branch is examined; nested control is not
// entered.
return isa<OpTy>(thenBody->front()) && isa<OpTy>(elseBody->front());
}
/// Hoists the EnableOp that ends both SeqOp branches of an IfOp out of the
/// branches. The IfOp is moved into a newly created SeqOp and the shared
/// enable is re-created directly after it, inside that SeqOp:
///
///                                        seq {
///   if %a with @G {                        if %a with @G {
///     seq { ... calyx.enable @A }            seq { ... }
///   } else {                      ->       } else {
///     seq { ... calyx.enable @A }            seq { ... }
///   }                                      }
///                                          calyx.enable @A
///                                        }
///
/// The pattern does not match unless both branches start with a SeqOp, both
/// SeqOps yield a last EnableOp, and those two EnableOps name the same group.
struct CommonTailPatternWithSeq : mlir::OpRewritePattern<IfOp> {
  using mlir::OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp ifOp,
                                PatternRewriter &rewriter) const override {
    if (!hasCommonTailPatternPreConditions<SeqOp>(ifOp))
      return failure();

    auto thenSeq = cast<SeqOp>(ifOp.getThenBody()->front());
    auto elseSeq = cast<SeqOp>(ifOp.getElseBody()->front());

    // Each branch must end in an enable, and both must enable the same group.
    EnableOp thenTail = getLastEnableOp(thenSeq);
    EnableOp elseTail = getLastEnableOp(elseSeq);
    if (!thenTail || !elseTail)
      return failure();
    const auto sharedGroup = thenTail.groupName();
    if (sharedGroup != elseTail.groupName())
      return failure();

    // Wrap the IfOp and the hoisted enable in their own sequential region.
    // If the IfOp lives inside a ParOp, this keeps the hoisted enable ordered
    // after the IfOp instead of running in parallel with it.
    rewriter.setInsertionPointAfter(ifOp);
    SeqOp wrapper = rewriter.create<SeqOp>(ifOp.getLoc());
    rewriter.createBlock(&wrapper.getBodyRegion());
    Block *wrapperBody = wrapper.getBody();
    ifOp->remove();
    wrapperBody->push_back(ifOp);
    rewriter.create<EnableOp>(wrapper.getLoc(), sharedGroup);

    // The shared enable now lives after the IfOp; drop both branch copies.
    rewriter.eraseOp(thenTail);
    rewriter.eraseOp(elseTail);
    return success();
  }
};
/// Hoists the EnableOps shared by both ParOp branches of an IfOp out of the
/// branches. Every group that is enabled immediately within the body of both
/// the `then` ParOp and the `else` ParOp is erased from the two branches and
/// re-created next to the IfOp, inside a newly created ParOp that wraps it:
///
///   if %a with @G {                  par {
///     par {                            if %a with @G {
///       ...                              par { ... }
///       calyx.enable @A                } else {
///       calyx.enable @B        ->        par { ... }
///     }                                }
///   } else {                           calyx.enable @A
///     par {                            calyx.enable @B
///       ...                          }
///       calyx.enable @A
///       calyx.enable @B
///     }
///   }
///
/// The pattern does not match unless both branches start with a ParOp and at
/// least one group is enabled in both of them.
struct CommonTailPatternWithPar : mlir::OpRewritePattern<IfOp> {
  using mlir::OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp ifOp,
                                PatternRewriter &rewriter) const override {
    if (!hasCommonTailPatternPreConditions<ParOp>(ifOp))
      return failure();

    auto thenControl = cast<ParOp>(ifOp.getThenBody()->front()),
         elseControl = cast<ParOp>(ifOp.getElseBody()->front());

    llvm::StringMap<EnableOp> A = getAllEnableOpsInImmediateBody(thenControl),
                              B = getAllEnableOpsInImmediateBody(elseControl);

    // Compute the intersection between `A` and `B`. The collected names refer
    // to the keys owned by `A`, which outlives every use below.
    SmallVector<StringRef> groupNames;
    for (auto a = A.begin(); a != A.end(); ++a) {
      StringRef groupName = a->getKey();
      auto b = B.find(groupName);
      if (b == B.end())
        continue;
      // This is also an element in B.
      groupNames.push_back(groupName);
      // Since these are being pulled out, erase them from both branches.
      rewriter.eraseOp(a->getValue());
      rewriter.eraseOp(b->getValue());
    }

    // Nothing is shared, so nothing was erased and there is nothing to hoist.
    // Reporting success here would wrap the IfOp in a fresh ParOp on every
    // application without making progress, so the rewrite would never settle.
    if (groupNames.empty())
      return failure();

    // Place the IfOp and EnableOp(s) inside a parallel region, in case this
    // IfOp is nested in a SeqOp. This avoids unintentionally sequentializing
    // the pulled out EnableOps.
    rewriter.setInsertionPointAfter(ifOp);
    ParOp parOp = rewriter.create<ParOp>(ifOp.getLoc());
    rewriter.createBlock(&parOp.getBodyRegion());
    Block *body = parOp.getBody();
    ifOp->remove();
    body->push_back(ifOp);

    // Re-create one EnableOp per shared group next to the IfOp.
    for (StringRef groupName : groupNames)
      rewriter.create<EnableOp>(parOp.getLoc(), groupName);
    return success();
  }
};
/// Registers the IfOp canonicalization patterns: common tail elimination for
/// SeqOp branches first, then for ParOp branches.
void IfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                       MLIRContext *context) {
  patterns.add<CommonTailPatternWithSeq>(context);
  patterns.add<CommonTailPatternWithPar>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -54,7 +54,7 @@ calyx.program "main" {
}
}
// IfOp removes common tails from within SeqOps.
// IfOp nested in SeqOp removes common tail from within SeqOps.
calyx.program "main" {
calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%done: i1 {done}) {
%r.in, %r.write_en, %r.clk, %r.reset, %r.out, %r.done = calyx.register "r" : i1, i1, i1, i1, i1, i1
@ -87,7 +87,7 @@ calyx.program "main" {
// CHECK-NEXT: calyx.seq {
// CHECK-NEXT: calyx.enable @B
// CHECK-NEXT: }
// CHECK-NEXT: else {
// CHECK-NEXT: } else {
// CHECK-NEXT: calyx.seq {
// CHECK-NEXT: calyx.enable @C
// CHECK-NEXT: }
@ -112,3 +112,202 @@ calyx.program "main" {
}
}
}
// IfOp nested in ParOp removes common tails from within ParOps.
calyx.program "main" {
calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%done: i1 {done}) {
%r.in, %r.write_en, %r.clk, %r.reset, %r.out, %r.done = calyx.register "r" : i1, i1, i1, i1, i1, i1
%eq.left, %eq.right, %eq.out = calyx.std_eq "eq" : i1, i1, i1
%c1_1 = hw.constant 1 : i1
calyx.wires {
calyx.comb_group @Cond {
calyx.assign %eq.left = %c1_1 : i1
calyx.assign %eq.right = %c1_1 : i1
}
calyx.group @A {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
calyx.group @B {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
calyx.group @C {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
calyx.group @D {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
}
// CHECK-LABEL: calyx.control {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.if %eq.out with @Cond {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.enable @A
// CHECK-NEXT: }
// CHECK-NEXT: } else {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.enable @B
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: calyx.enable @C
// CHECK-NEXT: calyx.enable @D
// CHECK-NEXT: }
// CHECK-NEXT: }
calyx.control {
calyx.par {
calyx.if %eq.out with @Cond {
calyx.par {
calyx.enable @A
calyx.enable @C
calyx.enable @D
}
} else {
calyx.par {
calyx.enable @B
calyx.enable @C
calyx.enable @D
}
}
}
}
}
}
// IfOp nested in ParOp removes common tail from within SeqOps. The important check
// here is ensuring the removed EnableOps are still computed sequentially.
calyx.program "main" {
calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%done: i1 {done}) {
%r.in, %r.write_en, %r.clk, %r.reset, %r.out, %r.done = calyx.register "r" : i1, i1, i1, i1, i1, i1
%eq.left, %eq.right, %eq.out = calyx.std_eq "eq" : i1, i1, i1
%c1_1 = hw.constant 1 : i1
calyx.wires {
calyx.comb_group @Cond {
calyx.assign %eq.left = %c1_1 : i1
calyx.assign %eq.right = %c1_1 : i1
}
calyx.group @A {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
calyx.group @B {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
calyx.group @C {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
}
// CHECK-LABEL: calyx.control {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.seq {
// CHECK-NEXT: calyx.if %eq.out with @Cond {
// CHECK-NEXT: calyx.seq {
// CHECK-NEXT: calyx.enable @B
// CHECK-NEXT: }
// CHECK-NEXT: } else {
// CHECK-NEXT: calyx.seq {
// CHECK-NEXT: calyx.enable @C
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: calyx.enable @A
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
calyx.control {
calyx.par {
calyx.if %eq.out with @Cond {
calyx.seq {
calyx.enable @B
calyx.enable @A
}
} else {
calyx.seq {
calyx.enable @C
calyx.enable @A
}
}
}
}
}
}
// IfOp nested in SeqOp removes common tail from within ParOps. The important check
// here is ensuring the removed EnableOps are still computed in parallel.
calyx.program "main" {
calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%done: i1 {done}) {
%r.in, %r.write_en, %r.clk, %r.reset, %r.out, %r.done = calyx.register "r" : i1, i1, i1, i1, i1, i1
%eq.left, %eq.right, %eq.out = calyx.std_eq "eq" : i1, i1, i1
%c1_1 = hw.constant 1 : i1
calyx.wires {
calyx.comb_group @Cond {
calyx.assign %eq.left = %c1_1 : i1
calyx.assign %eq.right = %c1_1 : i1
}
calyx.group @A {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
calyx.group @B {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
calyx.group @C {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
calyx.group @D {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
}
// CHECK-LABEL: calyx.control {
// CHECK-NEXT: calyx.seq {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.if %eq.out with @Cond {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.enable @A
// CHECK-NEXT: }
// CHECK-NEXT: } else {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.enable @B
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: calyx.enable @C
// CHECK-NEXT: calyx.enable @D
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
calyx.control {
calyx.seq {
calyx.if %eq.out with @Cond {
calyx.par {
calyx.enable @A
calyx.enable @C
calyx.enable @D
}
} else {
calyx.par {
calyx.enable @B
calyx.enable @C
calyx.enable @D
}
}
}
}
}
}