[RemoveVariableBound] factor out applyRemoveVariableBound() method (#20)

This commit is contained in:
Hanchen Ye 2021-01-07 16:44:09 -06:00
parent 18260436aa
commit 94d6d57dda
2 changed files with 57 additions and 42 deletions

View File

@ -5,6 +5,7 @@
#ifndef SCALEHLS_TRANSFORMS_PASSES_H #ifndef SCALEHLS_TRANSFORMS_PASSES_H
#define SCALEHLS_TRANSFORMS_PASSES_H #define SCALEHLS_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory> #include <memory>
@ -15,6 +16,9 @@ class Pass;
namespace mlir { namespace mlir {
namespace scalehls { namespace scalehls {
/// Optimization APIs.
bool applyRemoveVariableBound(AffineForOp loop, OpBuilder &builder);
/// Pragma optimization passes. /// Pragma optimization passes.
std::unique_ptr<Pass> createLoopPipeliningPass(); std::unique_ptr<Pass> createLoopPipeliningPass();
std::unique_ptr<Pass> createArrayPartitionPass(); std::unique_ptr<Pass> createArrayPartitionPass();

View File

@ -4,7 +4,6 @@
#include "Analysis/Utils.h" #include "Analysis/Utils.h"
#include "Transforms/Passes.h" #include "Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/IntegerSet.h" #include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/LoopUtils.h"
@ -13,21 +12,32 @@ using namespace mlir;
using namespace scalehls; using namespace scalehls;
namespace { namespace {
struct RemoveVariableBound : public RemoveVariableBoundBase<RemoveVariableBound> { struct RemoveVariableBound
void runOnOperation() override; : public RemoveVariableBoundBase<RemoveVariableBound> {
}; void runOnOperation() override {
} // namespace
void RemoveVariableBound::runOnOperation() {
auto func = getOperation(); auto func = getOperation();
auto builder = OpBuilder(func); auto builder = OpBuilder(func);
// Walk through all functions and loops. // Walk through all functions and loops.
for (auto forOp : func.getOps<mlir::AffineForOp>()) { for (auto loop : func.getOps<AffineForOp>())
SmallVector<mlir::AffineForOp, 4> nestedLoops; applyRemoveVariableBound(loop, builder);
// TODO: support imperfect loops. }
getPerfectlyNestedLoops(nestedLoops, forOp); };
} // namespace
bool scalehls::applyRemoveVariableBound(AffineForOp loop, OpBuilder &builder) {
SmallVector<AffineForOp, 4> nestedLoops;
getPerfectlyNestedLoops(nestedLoops, loop);
// Recursively apply remove variable bound for all child loops of the
// innermost loop of nestedLoops.
for (auto childLoop : nestedLoops.back().getOps<AffineForOp>())
if (applyRemoveVariableBound(childLoop, builder))
continue;
else
return false;
// Remove all vairable loop bound if possible.
for (auto loop : nestedLoops) { for (auto loop : nestedLoops) {
// TODO: support remove variable lower bound. // TODO: support remove variable lower bound.
if (!loop.hasConstantUpperBound()) { if (!loop.hasConstantUpperBound()) {
@ -39,23 +49,22 @@ void RemoveVariableBound::runOnOperation() {
auto upperMap = loop.getUpperBoundMap(); auto upperMap = loop.getUpperBoundMap();
auto ifExpr = upperMap.getResult(0) - auto ifExpr = upperMap.getResult(0) -
builder.getAffineDimExpr(upperMap.getNumDims()) - 1; builder.getAffineDimExpr(upperMap.getNumDims()) - 1;
auto ifCondition = IntegerSet::get(upperMap.getNumDims() + 1, 0, auto ifCondition = IntegerSet::get(upperMap.getNumDims() + 1, 0, ifExpr,
ifExpr, /*eqFlags=*/false); /*eqFlags=*/false);
auto ifOperands = SmallVector<Value, 4>(loop.getUpperBoundOperands()); auto ifOperands = SmallVector<Value, 4>(loop.getUpperBoundOperands());
ifOperands.push_back(loop.getInductionVar()); ifOperands.push_back(loop.getInductionVar());
// Create if operation in the front of the innermost perfect loop. // Create if operation in the front of the innermost perfect loop.
builder.setInsertionPointToStart(nestedLoops.back().getBody()); builder.setInsertionPointToStart(nestedLoops.back().getBody());
auto ifOp = builder.create<mlir::AffineIfOp>( auto ifOp =
func.getLoc(), ifCondition, ifOperands, builder.create<AffineIfOp>(loop.getLoc(), ifCondition, ifOperands,
/*withElseRegion=*/false); /*withElseRegion=*/false);
// Move all operations in the innermost perfect loop into the // Move all operations in the innermost perfect loop into the new
// new created AffineIf region. // created AffineIf region.
auto &ifBlock = ifOp.getBody()->getOperations(); auto &ifBlock = ifOp.getBody()->getOperations();
auto &loopBlock = nestedLoops.back().getBody()->getOperations(); auto &loopBlock = nestedLoops.back().getBody()->getOperations();
ifBlock.splice(ifBlock.begin(), loopBlock, ifBlock.splice(ifBlock.begin(), loopBlock, std::next(loopBlock.begin()),
std::next(loopBlock.begin()),
std::prev(loopBlock.end(), 1)); std::prev(loopBlock.end(), 1));
// Set constant variable bound. // Set constant variable bound.
@ -64,7 +73,9 @@ void RemoveVariableBound::runOnOperation() {
} }
} }
} }
}
// For now, this method will always success.
return true;
} }
std::unique_ptr<mlir::Pass> scalehls::createRemoveVariableBoundPass() { std::unique_ptr<mlir::Pass> scalehls::createRemoveVariableBoundPass() {