[CreateHLSCppPrimitive] Support multiplication packing; [LegalizeToHLSCpp] Lower permute transfer read/write; Raise memref load/store to affine load/store

This commit is contained in:
Hanchen Ye 2022-02-17 15:28:59 -06:00
parent b3c40d94bb
commit cb37d037b4
3 changed files with 61 additions and 57 deletions

View File

@ -4,6 +4,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "scalehls/Transforms/Passes.h"
@ -27,7 +28,7 @@ struct AddOpRewritePattern : public OpRewritePattern<arith::AddIOp> {
PatternRewriter &rewriter) const override {
// Figure out whether the add op can be rewritten.
auto dataType = getIntDataType(add.getType());
if (!dataType || dataType.getWidth() == 32)
if (!dataType || dataType.getWidth() == 32 || dataType.isSigned())
return failure();
// Generate new type.
@ -63,7 +64,7 @@ struct MulOpRewritePattern : public OpRewritePattern<arith::MulIOp> {
PatternRewriter &rewriter) const override {
// Figure out whether the mul op can be rewritten.
auto dataType = getIntDataType(mul.getType());
if (!dataType || dataType.getWidth() != 8)
if (!dataType || dataType.getWidth() != 8 || dataType.isSigned())
return failure();
// Generate new type.
@ -75,11 +76,18 @@ struct MulOpRewritePattern : public OpRewritePattern<arith::MulIOp> {
IntegerType::get(rewriter.getContext(), 16));
}
auto lhs = mul.getLhs();
if (auto broadcast = lhs.getDefiningOp<vector::BroadcastOp>())
lhs = broadcast.source();
auto rhs = mul.getRhs();
if (auto broadcast = rhs.getDefiningOp<vector::BroadcastOp>())
rhs = broadcast.source();
// Replace the original op with multiplication primitive op.
auto loc = mul.getLoc();
rewriter.setInsertionPoint(mul);
auto mulResult =
rewriter.create<MulPrimOp>(loc, newType, mul.getLhs(), mul.getRhs());
auto mulResult = rewriter.create<MulPrimOp>(loc, newType, lhs, rhs);
auto cast = rewriter.create<CastPrimOp>(loc, mul.getType(), mulResult);
rewriter.replaceOp(mul, cast.getResult());

View File

@ -6,8 +6,8 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "scalehls/Transforms/Passes.h"
#include "scalehls/Transforms/Utils.h"
@ -17,37 +17,39 @@ using namespace scalehls;
using namespace hlscpp;
namespace {
struct TransferReadConversionPattern
: public OpConversionPattern<vector::TransferReadOp> {
using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
/// Simple memref load to affine load raising.
struct MemrefLoadRewritePattern : public OpRewritePattern<memref::LoadOp> {
using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
LogicalResult
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op.permutation_map().isMinorIdentity() ||
!op.source().getType().isa<MemRefType>())
return failure();
rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(op, op.getType(),
op.source(), op.indices());
return success();
LogicalResult matchAndRewrite(memref::LoadOp load,
PatternRewriter &rewriter) const override {
if (llvm::all_of(load.getIndices(), [&](Value operand) {
return isValidDim(operand) || isValidSymbol(operand);
})) {
rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.memref(),
load.getIndices());
return success();
}
return failure();
}
};
} // namespace
namespace {
struct TransferWriteConversionPattern
: public OpConversionPattern<vector::TransferWriteOp> {
using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
/// Simple memref store to affine store raising.
struct MemrefStoreRewritePattern : public OpRewritePattern<memref::StoreOp> {
using OpRewritePattern<memref::StoreOp>::OpRewritePattern;
LogicalResult
matchAndRewrite(vector::TransferWriteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op.permutation_map().isMinorIdentity() ||
!op.source().getType().isa<MemRefType>())
return failure();
rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(op, op.vector(),
op.source(), op.indices());
return success();
LogicalResult matchAndRewrite(memref::StoreOp store,
PatternRewriter &rewriter) const override {
if (llvm::all_of(store.getIndices(), [&](Value operand) {
return isValidDim(operand) || isValidSymbol(operand);
})) {
rewriter.replaceOpWithNewOp<AffineStoreOp>(
store, store.value(), store.memref(), store.getIndices());
return success();
}
return failure();
}
};
} // namespace
@ -119,19 +121,11 @@ bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc) {
++idx;
}
RewritePatternSet patterns(func.getContext());
ConversionTarget target(*func.getContext());
// TODO: Make sure the lowering is safe and thorough.
// patterns.add<TransferReadConversionPattern>(patterns.getContext());
// patterns.add<TransferWriteConversionPattern>(patterns.getContext());
// target.addLegalOp<AffineVectorLoadOp>();
// target.addLegalOp<AffineVectorStoreOp>();
// (void)applyPartialConversion(func, target, std::move(patterns));
// patterns.clear();
// vector::populateVectorTransferLoweringPatterns(patterns);
// (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
mlir::RewritePatternSet patterns(func.getContext());
patterns.add<MemrefLoadRewritePattern>(func.getContext());
patterns.add<MemrefStoreRewritePattern>(func.getContext());
vector::populateVectorTransferLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
return true;
}

View File

@ -1235,23 +1235,25 @@ void ModuleEmitter::emitReinterpretCast(memref::ReinterpretCastOp op) {
/// HLSCpp primitive operation emitters.
void ModuleEmitter::emitMulPrim(MulPrimOp op) {
if (op.isPackMul()) {
if (op.A().getType().isa<VectorType>()) {
os << "pack_mul(";
emitValue(op.A());
os << ", ";
emitValue(op.B());
os << ", ";
// Declare the result C array.
if (!isDeclared(op.C())) {
indent();
emitArrayDecl(op.C());
os << ";\n";
indent() << "#pragma HLS array_partition variable=";
emitValue(op.C());
os << ");";
} else {
os << "pack_mul(";
emitValue(op.B());
os << ", ";
emitValue(op.A());
os << ", ";
emitValue(op.C());
os << ");";
os << " complete dim=0\n";
}
auto AIsVector = op.A().getType().isa<VectorType>();
indent() << "pack_mul(";
emitValue(AIsVector ? op.A() : op.B());
os << ", ";
emitValue(AIsVector ? op.B() : op.A());
os << ", ";
emitValue(op.C());
os << ");";
emitInfoAndNewLine(op);
} else {