124 lines
4.2 KiB
C++
124 lines
4.2 KiB
C++
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Copyright 2020-2021 The ScaleHLS Authors.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "scalehls/Transforms/Passes.h"
|
|
|
|
using namespace mlir;
|
|
using namespace scalehls;
|
|
|
|
namespace {
|
|
struct ClampOpRewritePattern : public OpRewritePattern<tosa::ClampOp> {
|
|
using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::ClampOp clamp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto transpose = clamp.input().getDefiningOp<tosa::TransposeOp>();
|
|
if (!transpose)
|
|
return failure();
|
|
|
|
clamp.inputMutable().assign(transpose.input1());
|
|
clamp.output().setType(transpose.input1().getType());
|
|
|
|
rewriter.setInsertionPointAfter(clamp);
|
|
auto cloneTranspose = cast<tosa::TransposeOp>(rewriter.clone(*transpose));
|
|
|
|
clamp.output().replaceAllUsesWith(cloneTranspose.output());
|
|
cloneTranspose.input1Mutable().assign(clamp.output());
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// A helper to get permuatation vector from value.
|
|
static SmallVector<int64_t, 6> getPermValues(Value perm) {
|
|
DenseIntElementsAttr permAttr;
|
|
if (!matchPattern(perm, m_Constant(&permAttr)))
|
|
return {};
|
|
|
|
return llvm::to_vector<6>(
|
|
llvm::map_range(permAttr.getValues<APInt>(),
|
|
[](const APInt &val) { return val.getSExtValue(); }));
|
|
}
|
|
|
|
namespace {
|
|
struct TransposeOpRewritePattern : public OpRewritePattern<tosa::TransposeOp> {
|
|
using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::TransposeOp transpose,
|
|
PatternRewriter &rewriter) const override {
|
|
auto inputTranspose = transpose.input1().getDefiningOp<tosa::TransposeOp>();
|
|
if (!inputTranspose)
|
|
return failure();
|
|
|
|
auto permValues = getPermValues(transpose.perms());
|
|
auto inputPermValues = getPermValues(inputTranspose.perms());
|
|
assert(permValues.size() == inputPermValues.size() &&
|
|
"unexpected permutation values");
|
|
|
|
for (unsigned i = 0, e = permValues.size(); i < e; ++i)
|
|
if (inputPermValues[permValues[i]] != i)
|
|
return failure();
|
|
|
|
rewriter.replaceOp(transpose, inputTranspose.input1());
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
/// TODO: Expand this to all binary elementwise operations.
|
|
struct AddOpRewritePattern : public OpRewritePattern<tosa::AddOp> {
|
|
using OpRewritePattern<tosa::AddOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::AddOp add,
|
|
PatternRewriter &rewriter) const override {
|
|
auto input1Transpose = add.input1().getDefiningOp<tosa::TransposeOp>();
|
|
auto input2Transpose = add.input2().getDefiningOp<tosa::TransposeOp>();
|
|
if (!input1Transpose || !input2Transpose)
|
|
return failure();
|
|
|
|
if (getPermValues(input1Transpose.perms()) !=
|
|
getPermValues(input2Transpose.perms()))
|
|
return failure();
|
|
|
|
add.input1Mutable().assign(input1Transpose.input1());
|
|
add.input2Mutable().assign(input2Transpose.input1());
|
|
add.output().setType(input1Transpose.input1().getType());
|
|
|
|
rewriter.setInsertionPointAfter(add);
|
|
auto cloneTranspose =
|
|
cast<tosa::TransposeOp>(rewriter.clone(*input1Transpose));
|
|
|
|
add.output().replaceAllUsesWith(cloneTranspose.output());
|
|
cloneTranspose.input1Mutable().assign(add.output());
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
struct TosaSimplifyGraph : public TosaSimplifyGraphBase<TosaSimplifyGraph> {
|
|
void runOnOperation() override {
|
|
auto func = getOperation();
|
|
auto context = func.getContext();
|
|
|
|
mlir::RewritePatternSet patterns(context);
|
|
patterns.add<ClampOpRewritePattern>(context);
|
|
patterns.add<TransposeOpRewritePattern>(context);
|
|
patterns.add<AddOpRewritePattern>(context);
|
|
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> scalehls::createTosaSimplifyGraphPass() {
|
|
return std::make_unique<TosaSimplifyGraph>();
|
|
}
|