//===----------------------------------------------------------------------===//
//
// Copyright 2020-2021 The ScaleHLS Authors.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "scalehls/Transforms/Passes.h"
|
|
|
|
using namespace mlir;
|
|
using namespace scalehls;
|
|
|
|
namespace {
|
|
/// From the semantics point of view, reshape should not introduce a redundant
|
|
/// memref copy. However, in HLS, a reinterpret-like statement will obstruct the
|
|
/// array partition of the on-chip memory. Therefore, we convert reshape to
|
|
/// explict linalg generic operation in this lowering.
|
|
struct ReshapeOpRewritePattern : public OpRewritePattern<tosa::ReshapeOp> {
|
|
using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::ReshapeOp reshape,
|
|
PatternRewriter &rewriter) const override {
|
|
rewriter.setInsertionPoint(reshape);
|
|
auto inputType = reshape.input1().getType().cast<TensorType>();
|
|
auto outputType = reshape.getType();
|
|
|
|
// A helper to get the memory access map.
|
|
auto getIndexMap = [&](TensorType type) {
|
|
unsigned rank = type.getRank();
|
|
SmallVector<AffineExpr, 4> exprs(rank, rewriter.getAffineDimExpr(0));
|
|
|
|
for (unsigned dim = 0; dim < rank; ++dim)
|
|
for (unsigned idx = 0; idx < dim; ++idx)
|
|
exprs[idx] = exprs[idx].floorDiv(type.getDimSize(dim));
|
|
|
|
for (unsigned idx = 0; idx < rank; ++idx) {
|
|
auto &expr = exprs[idx];
|
|
if (auto constantExpr = expr.dyn_cast<AffineDimExpr>())
|
|
if (outputType.getNumElements() <= type.getDimSize(idx))
|
|
continue;
|
|
expr = expr % type.getDimSize(idx);
|
|
}
|
|
return AffineMap::get(/*dimCount=*/1, 0, exprs, rewriter.getContext());
|
|
};
|
|
|
|
// Create linalg init tensor and generic operation.
|
|
auto init = rewriter.create<linalg::InitTensorOp>(
|
|
reshape.getLoc(), outputType.getShape(), outputType.getElementType());
|
|
auto generic = rewriter.replaceOpWithNewOp<linalg::GenericOp>(
|
|
reshape, TypeRange(outputType), ValueRange(reshape.input1()),
|
|
ValueRange(init.result()),
|
|
SmallVector<AffineMap>(
|
|
{getIndexMap(inputType), getIndexMap(outputType)}),
|
|
SmallVector<StringRef>({"parallel"}));
|
|
|
|
// Create the body of generic operation that directly yield the input
|
|
// argument as result.
|
|
auto entry = rewriter.createBlock(&generic.getBodyRegion());
|
|
auto arg = entry->addArgument(inputType.getElementType(), reshape.getLoc());
|
|
entry->addArgument(outputType.getElementType(), reshape.getLoc());
|
|
|
|
rewriter.setInsertionPointToEnd(entry);
|
|
rewriter.create<linalg::YieldOp>(reshape.getLoc(), arg);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
struct TosaToLinalgCleanup
|
|
: public TosaToLinalgCleanupBase<TosaToLinalgCleanup> {
|
|
void runOnOperation() override {
|
|
auto func = getOperation();
|
|
auto context = func.getContext();
|
|
|
|
ConversionTarget target(*context);
|
|
target.addIllegalOp<tensor::PadOp, tosa::ReshapeOp>();
|
|
target.addLegalOp<linalg::GenericOp, linalg::YieldOp, linalg::InitTensorOp,
|
|
linalg::FillOp, arith::ConstantOp>();
|
|
|
|
mlir::RewritePatternSet patterns(context);
|
|
patterns.add<ReshapeOpRewritePattern>(context);
|
|
patterns.add<linalg::PadOpTransformationPattern>(context);
|
|
|
|
if (failed(applyPartialConversion(func, target, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> scalehls::createTosaToLinalgCleanupPass() {
|
|
return std::make_unique<TosaToLinalgCleanup>();
|
|
}
|