[RemoveVariableBound] factor out applyRemoveVariableBound() method (#20)
This commit is contained in:
parent
18260436aa
commit
94d6d57dda
|
@ -5,6 +5,7 @@
|
|||
#ifndef SCALEHLS_TRANSFORMS_PASSES_H
|
||||
#define SCALEHLS_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include <memory>
|
||||
|
||||
|
@ -15,6 +16,9 @@ class Pass;
|
|||
namespace mlir {
|
||||
namespace scalehls {
|
||||
|
||||
/// Optimization APIs.
|
||||
bool applyRemoveVariableBound(AffineForOp loop, OpBuilder &builder);
|
||||
|
||||
/// Pragma optimization passes.
|
||||
std::unique_ptr<Pass> createLoopPipeliningPass();
|
||||
std::unique_ptr<Pass> createArrayPartitionPass();
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
|
||||
#include "Analysis/Utils.h"
|
||||
#include "Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
|
||||
|
@ -13,21 +12,32 @@ using namespace mlir;
|
|||
using namespace scalehls;
|
||||
|
||||
namespace {
|
||||
struct RemoveVariableBound : public RemoveVariableBoundBase<RemoveVariableBound> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void RemoveVariableBound::runOnOperation() {
|
||||
struct RemoveVariableBound
|
||||
: public RemoveVariableBoundBase<RemoveVariableBound> {
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto builder = OpBuilder(func);
|
||||
|
||||
// Walk through all functions and loops.
|
||||
for (auto forOp : func.getOps<mlir::AffineForOp>()) {
|
||||
SmallVector<mlir::AffineForOp, 4> nestedLoops;
|
||||
// TODO: support imperfect loops.
|
||||
getPerfectlyNestedLoops(nestedLoops, forOp);
|
||||
for (auto loop : func.getOps<AffineForOp>())
|
||||
applyRemoveVariableBound(loop, builder);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
bool scalehls::applyRemoveVariableBound(AffineForOp loop, OpBuilder &builder) {
|
||||
SmallVector<AffineForOp, 4> nestedLoops;
|
||||
getPerfectlyNestedLoops(nestedLoops, loop);
|
||||
|
||||
// Recursively apply remove variable bound for all child loops of the
|
||||
// innermost loop of nestedLoops.
|
||||
for (auto childLoop : nestedLoops.back().getOps<AffineForOp>())
|
||||
if (applyRemoveVariableBound(childLoop, builder))
|
||||
continue;
|
||||
else
|
||||
return false;
|
||||
|
||||
// Remove all vairable loop bound if possible.
|
||||
for (auto loop : nestedLoops) {
|
||||
// TODO: support remove variable lower bound.
|
||||
if (!loop.hasConstantUpperBound()) {
|
||||
|
@ -39,23 +49,22 @@ void RemoveVariableBound::runOnOperation() {
|
|||
auto upperMap = loop.getUpperBoundMap();
|
||||
auto ifExpr = upperMap.getResult(0) -
|
||||
builder.getAffineDimExpr(upperMap.getNumDims()) - 1;
|
||||
auto ifCondition = IntegerSet::get(upperMap.getNumDims() + 1, 0,
|
||||
ifExpr, /*eqFlags=*/false);
|
||||
auto ifCondition = IntegerSet::get(upperMap.getNumDims() + 1, 0, ifExpr,
|
||||
/*eqFlags=*/false);
|
||||
auto ifOperands = SmallVector<Value, 4>(loop.getUpperBoundOperands());
|
||||
ifOperands.push_back(loop.getInductionVar());
|
||||
|
||||
// Create if operation in the front of the innermost perfect loop.
|
||||
builder.setInsertionPointToStart(nestedLoops.back().getBody());
|
||||
auto ifOp = builder.create<mlir::AffineIfOp>(
|
||||
func.getLoc(), ifCondition, ifOperands,
|
||||
auto ifOp =
|
||||
builder.create<AffineIfOp>(loop.getLoc(), ifCondition, ifOperands,
|
||||
/*withElseRegion=*/false);
|
||||
|
||||
// Move all operations in the innermost perfect loop into the
|
||||
// new created AffineIf region.
|
||||
// Move all operations in the innermost perfect loop into the new
|
||||
// created AffineIf region.
|
||||
auto &ifBlock = ifOp.getBody()->getOperations();
|
||||
auto &loopBlock = nestedLoops.back().getBody()->getOperations();
|
||||
ifBlock.splice(ifBlock.begin(), loopBlock,
|
||||
std::next(loopBlock.begin()),
|
||||
ifBlock.splice(ifBlock.begin(), loopBlock, std::next(loopBlock.begin()),
|
||||
std::prev(loopBlock.end(), 1));
|
||||
|
||||
// Set constant variable bound.
|
||||
|
@ -64,7 +73,9 @@ void RemoveVariableBound::runOnOperation() {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For now, this method will always success.
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::Pass> scalehls::createRemoveVariableBoundPass() {
|
||||
|
|
Loading…
Reference in New Issue