[CreateHLSCppPrimitive] Support multiplication packing; [LegalizeToHLSCpp] Lower permute transfer read/write; Raise memref load/store to affine load/store
This commit is contained in:
parent
b3c40d94bb
commit
cb37d037b4
|
@ -4,6 +4,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "scalehls/Transforms/Passes.h"
|
||||
|
@ -27,7 +28,7 @@ struct AddOpRewritePattern : public OpRewritePattern<arith::AddIOp> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
// Figure out whether the add op can be rewritten.
|
||||
auto dataType = getIntDataType(add.getType());
|
||||
if (!dataType || dataType.getWidth() == 32)
|
||||
if (!dataType || dataType.getWidth() == 32 || dataType.isSigned())
|
||||
return failure();
|
||||
|
||||
// Generate new type.
|
||||
|
@ -63,7 +64,7 @@ struct MulOpRewritePattern : public OpRewritePattern<arith::MulIOp> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
// Figure out whether the mul op can be rewritten.
|
||||
auto dataType = getIntDataType(mul.getType());
|
||||
if (!dataType || dataType.getWidth() != 8)
|
||||
if (!dataType || dataType.getWidth() != 8 || dataType.isSigned())
|
||||
return failure();
|
||||
|
||||
// Generate new type.
|
||||
|
@ -75,11 +76,18 @@ struct MulOpRewritePattern : public OpRewritePattern<arith::MulIOp> {
|
|||
IntegerType::get(rewriter.getContext(), 16));
|
||||
}
|
||||
|
||||
auto lhs = mul.getLhs();
|
||||
if (auto broadcast = lhs.getDefiningOp<vector::BroadcastOp>())
|
||||
lhs = broadcast.source();
|
||||
|
||||
auto rhs = mul.getRhs();
|
||||
if (auto broadcast = rhs.getDefiningOp<vector::BroadcastOp>())
|
||||
rhs = broadcast.source();
|
||||
|
||||
// Replace the original op with multiplication primitive op.
|
||||
auto loc = mul.getLoc();
|
||||
rewriter.setInsertionPoint(mul);
|
||||
auto mulResult =
|
||||
rewriter.create<MulPrimOp>(loc, newType, mul.getLhs(), mul.getRhs());
|
||||
auto mulResult = rewriter.create<MulPrimOp>(loc, newType, lhs, rhs);
|
||||
auto cast = rewriter.create<CastPrimOp>(loc, mul.getType(), mulResult);
|
||||
rewriter.replaceOp(mul, cast.getResult());
|
||||
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "scalehls/Transforms/Passes.h"
|
||||
#include "scalehls/Transforms/Utils.h"
|
||||
|
@ -17,37 +17,39 @@ using namespace scalehls;
|
|||
using namespace hlscpp;
|
||||
|
||||
namespace {
|
||||
struct TransferReadConversionPattern
|
||||
: public OpConversionPattern<vector::TransferReadOp> {
|
||||
using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
|
||||
/// Simple memref load to affine load raising.
|
||||
struct MemrefLoadRewritePattern : public OpRewritePattern<memref::LoadOp> {
|
||||
using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!op.permutation_map().isMinorIdentity() ||
|
||||
!op.source().getType().isa<MemRefType>())
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(op, op.getType(),
|
||||
op.source(), op.indices());
|
||||
return success();
|
||||
LogicalResult matchAndRewrite(memref::LoadOp load,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (llvm::all_of(load.getIndices(), [&](Value operand) {
|
||||
return isValidDim(operand) || isValidSymbol(operand);
|
||||
})) {
|
||||
rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.memref(),
|
||||
load.getIndices());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
struct TransferWriteConversionPattern
|
||||
: public OpConversionPattern<vector::TransferWriteOp> {
|
||||
using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
|
||||
/// Simple memref store to affine store raising.
|
||||
struct MemrefStoreRewritePattern : public OpRewritePattern<memref::StoreOp> {
|
||||
using OpRewritePattern<memref::StoreOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::TransferWriteOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!op.permutation_map().isMinorIdentity() ||
|
||||
!op.source().getType().isa<MemRefType>())
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(op, op.vector(),
|
||||
op.source(), op.indices());
|
||||
return success();
|
||||
LogicalResult matchAndRewrite(memref::StoreOp store,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (llvm::all_of(store.getIndices(), [&](Value operand) {
|
||||
return isValidDim(operand) || isValidSymbol(operand);
|
||||
})) {
|
||||
rewriter.replaceOpWithNewOp<AffineStoreOp>(
|
||||
store, store.value(), store.memref(), store.getIndices());
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -119,19 +121,11 @@ bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc) {
|
|||
++idx;
|
||||
}
|
||||
|
||||
RewritePatternSet patterns(func.getContext());
|
||||
ConversionTarget target(*func.getContext());
|
||||
|
||||
// TODO: Make sure the lowering is safe and thorough.
|
||||
// patterns.add<TransferReadConversionPattern>(patterns.getContext());
|
||||
// patterns.add<TransferWriteConversionPattern>(patterns.getContext());
|
||||
// target.addLegalOp<AffineVectorLoadOp>();
|
||||
// target.addLegalOp<AffineVectorStoreOp>();
|
||||
// (void)applyPartialConversion(func, target, std::move(patterns));
|
||||
|
||||
// patterns.clear();
|
||||
// vector::populateVectorTransferLoweringPatterns(patterns);
|
||||
// (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
mlir::RewritePatternSet patterns(func.getContext());
|
||||
patterns.add<MemrefLoadRewritePattern>(func.getContext());
|
||||
patterns.add<MemrefStoreRewritePattern>(func.getContext());
|
||||
vector::populateVectorTransferLoweringPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -1235,23 +1235,25 @@ void ModuleEmitter::emitReinterpretCast(memref::ReinterpretCastOp op) {
|
|||
/// HLSCpp primitive operation emitters.
|
||||
void ModuleEmitter::emitMulPrim(MulPrimOp op) {
|
||||
if (op.isPackMul()) {
|
||||
if (op.A().getType().isa<VectorType>()) {
|
||||
os << "pack_mul(";
|
||||
emitValue(op.A());
|
||||
os << ", ";
|
||||
emitValue(op.B());
|
||||
os << ", ";
|
||||
// Declare the result C array.
|
||||
if (!isDeclared(op.C())) {
|
||||
indent();
|
||||
emitArrayDecl(op.C());
|
||||
os << ";\n";
|
||||
|
||||
indent() << "#pragma HLS array_partition variable=";
|
||||
emitValue(op.C());
|
||||
os << ");";
|
||||
} else {
|
||||
os << "pack_mul(";
|
||||
emitValue(op.B());
|
||||
os << ", ";
|
||||
emitValue(op.A());
|
||||
os << ", ";
|
||||
emitValue(op.C());
|
||||
os << ");";
|
||||
os << " complete dim=0\n";
|
||||
}
|
||||
|
||||
auto AIsVector = op.A().getType().isa<VectorType>();
|
||||
indent() << "pack_mul(";
|
||||
emitValue(AIsVector ? op.A() : op.B());
|
||||
os << ", ";
|
||||
emitValue(AIsVector ? op.B() : op.A());
|
||||
os << ", ";
|
||||
emitValue(op.C());
|
||||
os << ");";
|
||||
emitInfoAndNewLine(op);
|
||||
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue