[HLSCpp] Add StreamChannel/Read/Write/BufferOp; Rename Mul/CastPrimOp to PrimMul/CastOp; Rename AssignOp to BufferOp

commit 93ebdb5412 (parent 9f8df6f437)
Author: Hanchen Ye
Date: 2022-03-19 16:09:40 -05:00
14 changed files with 114 additions and 76 deletions

View File

@@ -9,11 +9,49 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
+//===----------------------------------------------------------------------===//
+// Stream Operations
+//===----------------------------------------------------------------------===//
+class StreamOf<list<Type> allowedTypes> :
+Type<And<[CPred<"$_self.isa<::mlir::scalehls::hlscpp::StreamType>()">,
+Concat<"[](::mlir::Type elementType) { return ",
+SubstLeaves<"$_self", "elementType",
+AnyTypeOf<allowedTypes>.predicate>,
+"; }($_self.cast<::mlir::scalehls::hlscpp::StreamType>()"
+".getElementType())">]>>;
+def StreamChannelOp : HLSCppOp<"stream.channel"> {
+let summary = "Stream channel declaration operation";
+let results = (outs StreamOf<[AnyType]>:$channel);
+}
+def StreamReadOp : HLSCppOp<"stream.read"> {
+let summary = "Stream channel read operation";
+let arguments = (ins StreamOf<[AnyType]>:$channel);
+let results = (outs Optional<AnyType>:$result);
+}
+def StreamWriteOp : HLSCppOp<"stream.write"> {
+let summary = "Stream channel write operation";
+let arguments = (ins StreamOf<[AnyType]>:$channel, Optional<AnyType>:$value);
+}
+def StreamBufferOp : HLSCppOp<"stream.buffer"> {
+let summary = "Stream channel buffer operation";
+let arguments = (ins StreamOf<[AnyType]>:$input);
+let results = (outs StreamOf<[AnyType]>:$output);
+}
//===----------------------------------------------------------------------===//
// Primitive Operations
//===----------------------------------------------------------------------===//
-def MulPrimOp : HLSCppOp<"mul_prim", [NoSideEffect]> {
+def PrimMulOp : HLSCppOp<"prim.mul", [NoSideEffect]> {
let summary = "Multiplication primitive operation";
let description = [{
This primitive performs C = A * B, where A and B are 8-bits integers, while
@@ -39,31 +77,28 @@ def MulPrimOp : HLSCppOp<"mul_prim", [NoSideEffect]> {
let extraClassDeclaration = [{ bool isPackMul(); }];
}
-def CastPrimOp : HLSCppOp<"cast_prim",
+def PrimCastOp : HLSCppOp<"prim.cast",
[SameOperandsAndResultShape, NoSideEffect]> {
let summary = "Cast primitive operation";
let hasCanonicalizer = 1;
let arguments = (ins
-AnyTypeOf<[I8, I16, I32, VectorOfLengthAndType<[2], [I8, I16, I32]>]>:$in
+AnyTypeOf<[I8, I16, I32,
+VectorOfLengthAndType<[2], [I8, I16, I32]>]>:$input
);
let results = (outs
-AnyTypeOf<[I8, I16, I32, VectorOfLengthAndType<[2], [I8, I16, I32]>]>:$out
+AnyTypeOf<[I8, I16, I32,
+VectorOfLengthAndType<[2], [I8, I16, I32]>]>:$output
);
}
-def AssignOp : HLSCppOp<"assign",
+def BufferOp : HLSCppOp<"buffer",
[SameOperandsAndResultElementType, NoSideEffect]> {
let summary = "Assign the input value to the output";
let description = [{
This hlscpp.assign operation assigns the input value to the output, and can
be inserted anywhere without changing the original semantics. This is useful
for EmitHLSCpp to handle some corner cases and for tensor copy.
}];
let summary = "Buffer the input value";
let hasCanonicalizer = 1;
-let arguments = (ins AnyType : $input);
-let results = (outs AnyType : $output);
+let arguments = (ins AnyType:$input);
+let results = (outs AnyType:$output);
}
#endif // SCALEHLS_DIALECT_HLSCPP_OPS_TD
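Taken together, the four new stream ops model a FIFO channel: declared once by stream.channel, then read, written, or re-buffered through StreamOf-constrained operands. None of them defines a custom assembly format, so they print in MLIR's generic form. A minimal usage sketch, not taken from the commit (hypothetical SSA values; the i32 element type and depth of 2 are chosen for illustration only):

  %ch = "hlscpp.stream.channel"() : () -> !hlscpp.stream<i32, 2>
  "hlscpp.stream.write"(%ch, %v) : (!hlscpp.stream<i32, 2>, i32) -> ()
  %r = "hlscpp.stream.read"(%ch) : (!hlscpp.stream<i32, 2>) -> i32
  %dup = "hlscpp.stream.buffer"(%ch) : (!hlscpp.stream<i32, 2>) -> !hlscpp.stream<i32, 2>

Since StreamReadOp declares its result as Optional, a read that only pops the channel can also be emitted with no result.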

View File

@@ -15,9 +15,12 @@ def StreamType : HLSCppType<"Stream"> {
}];
let mnemonic = "stream";
let parameters = (ins "::mlir::Type":$elementType);
let parameters = (ins
"::mlir::Type":$elementType,
"int64_t":$depth
);
-let assemblyFormat = "`<` qualified($elementType) `>`";
+let assemblyFormat = "`<` qualified($elementType) `,` $depth `>`";
}
#endif // SCALEHLS_DIALECT_HLSCPP_TYPES_TD
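The new depth parameter changes the textual spelling of the type, for example (depth value arbitrary):

  !hlscpp.stream<i32>     // accepted before this commit
  !hlscpp.stream<i32, 2>  // after: element type plus FIFO depth

Existing IR that spells the old single-parameter form no longer parses and needs updating.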

View File

@@ -53,7 +53,7 @@
bufferization::ToMemrefOp, bufferization::ToTensorOp,
// HLSCpp primitive operations.
-MulPrimOp, CastPrimOp, AssignOp,
+PrimMulOp, PrimCastOp, BufferOp,
// Control flow operations.
func::CallOp, func::ReturnOp,
@@ -146,9 +146,9 @@
HANDLE(bufferization::ToTensorOp);
// HLSCpp primitive operations.
-HANDLE(MulPrimOp);
-HANDLE(CastPrimOp);
-HANDLE(AssignOp);
+HANDLE(PrimMulOp);
+HANDLE(PrimCastOp);
+HANDLE(BufferOp);
// Control flow operations.
HANDLE(func::CallOp);

View File

@@ -35,10 +35,10 @@ void HLSCppDialect::initialize() {
}
//===----------------------------------------------------------------------===//
-// MulPrimOp
+// PrimMulOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(MulPrimOp op) {
+static LogicalResult verify(PrimMulOp op) {
auto AIsVector = op.A().getType().isa<VectorType>();
auto BIsVector = op.B().getType().isa<VectorType>();
auto CIsVector = op.C().getType().isa<VectorType>();
@@ -50,7 +50,7 @@ static LogicalResult verify(MulPrimOp op) {
return failure();
}
-bool MulPrimOp::isPackMul() {
+bool PrimMulOp::isPackMul() {
auto AIsVector = A().getType().isa<VectorType>();
auto BIsVector = B().getType().isa<VectorType>();
return (AIsVector && !BIsVector) || (!AIsVector && BIsVector);
@@ -194,22 +194,22 @@ void FuncDirectiveAttr::print(AsmPrinter &p) const {
//===----------------------------------------------------------------------===//
namespace {
-struct SimplifyCastPrimOp : public OpRewritePattern<CastPrimOp> {
-using OpRewritePattern<CastPrimOp>::OpRewritePattern;
+struct SimplifyPrimCastOp : public OpRewritePattern<PrimCastOp> {
+using OpRewritePattern<PrimCastOp>::OpRewritePattern;
-LogicalResult matchAndRewrite(CastPrimOp cast,
+LogicalResult matchAndRewrite(PrimCastOp cast,
PatternRewriter &rewriter) const override {
-if (cast.in().getType() == cast.out().getType()) {
-rewriter.replaceOp(cast, cast.in());
+if (cast.input().getType() == cast.output().getType()) {
+rewriter.replaceOp(cast, cast.input());
return success();
}
// If the input of the cast is defined by another cast, then the two casts
// can be merged into one.
-if (cast.in().hasOneUse())
-if (auto defCast = cast.in().getDefiningOp<CastPrimOp>()) {
-rewriter.replaceOpWithNewOp<CastPrimOp>(cast, cast.getType(),
-defCast.in());
+if (cast.input().hasOneUse())
+if (auto defCast = cast.input().getDefiningOp<PrimCastOp>()) {
+rewriter.replaceOpWithNewOp<PrimCastOp>(cast, cast.getType(),
+defCast.input());
return success();
}
@@ -218,19 +218,19 @@ struct SimplifyCastPrimOp : public OpRewritePattern<CastPrimOp> {
};
} // namespace
-void CastPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
+void PrimCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
-results.add<SimplifyCastPrimOp>(context);
+results.add<SimplifyPrimCastOp>(context);
}
namespace {
-struct SimplifyAssignOp : public OpRewritePattern<AssignOp> {
-using OpRewritePattern<AssignOp>::OpRewritePattern;
+struct SimplifyBufferOp : public OpRewritePattern<BufferOp> {
+using OpRewritePattern<BufferOp>::OpRewritePattern;
-LogicalResult matchAndRewrite(AssignOp assign,
+LogicalResult matchAndRewrite(BufferOp buffer,
PatternRewriter &rewriter) const override {
-if (auto defOp = assign.input().getDefiningOp<AssignOp>()) {
-assign.inputMutable().assign(defOp.input());
+if (auto defOp = buffer.input().getDefiningOp<BufferOp>()) {
+buffer.inputMutable().assign(defOp.input());
return success();
}
return failure();
@@ -238,9 +238,9 @@ struct SimplifyAssignOp : public OpRewritePattern<AssignOp> {
};
} // namespace
-void AssignOp::getCanonicalizationPatterns(RewritePatternSet &results,
+void BufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
-results.add<SimplifyAssignOp>(context);
+results.add<SimplifyBufferOp>(context);
}
//===----------------------------------------------------------------------===//
// Include tablegen classes
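Both canonicalizers collapse chains of the renamed ops. A before/after sketch of the IR they target (hypothetical values, generic form):

  // An identity cast folds away, and a single-use cast feeding another
  // cast merges into one cast from the original input:
  %0 = "hlscpp.prim.cast"(%a) : (i8) -> i16
  %1 = "hlscpp.prim.cast"(%0) : (i16) -> i32
  // ...becomes...
  %1 = "hlscpp.prim.cast"(%a) : (i8) -> i32

  // A buffer fed by another buffer is rewired to the upstream input,
  // leaving the inner buffer dead if it has no other users:
  %2 = "hlscpp.buffer"(%b) : (f32) -> f32
  %3 = "hlscpp.buffer"(%2) : (f32) -> f32
  // ...becomes...
  %3 = "hlscpp.buffer"(%b) : (f32) -> f32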

View File

@@ -39,15 +39,15 @@ struct AddOpRewritePattern : public OpRewritePattern<arith::AddIOp> {
// Cast add op operand from the new type.
auto loc = add.getLoc();
rewriter.setInsertionPoint(add);
-auto newLhs = rewriter.create<CastPrimOp>(loc, newType, add.getLhs());
-auto newRhs = rewriter.create<CastPrimOp>(loc, newType, add.getRhs());
+auto newLhs = rewriter.create<PrimCastOp>(loc, newType, add.getLhs());
+auto newRhs = rewriter.create<PrimCastOp>(loc, newType, add.getRhs());
add.getLhsMutable().assign(newLhs);
add.getRhsMutable().assign(newRhs);
// Cast add op result to the new type.
rewriter.setInsertionPointAfter(add);
auto cast =
-rewriter.create<CastPrimOp>(loc, add.getType(), add.getResult());
+rewriter.create<PrimCastOp>(loc, add.getType(), add.getResult());
add.getResult().replaceAllUsesExcept(cast.getResult(), cast);
add.getResult().setType(newType);
@@ -87,8 +87,8 @@ struct MulOpRewritePattern : public OpRewritePattern<arith::MulIOp> {
// Replace the original op with multiplication primitive op.
auto loc = mul.getLoc();
rewriter.setInsertionPoint(mul);
-auto mulResult = rewriter.create<MulPrimOp>(loc, newType, lhs, rhs);
-auto cast = rewriter.create<CastPrimOp>(loc, mul.getType(), mulResult);
+auto mulResult = rewriter.create<PrimMulOp>(loc, newType, lhs, rhs);
+auto cast = rewriter.create<PrimCastOp>(loc, mul.getType(), mulResult);
rewriter.replaceOp(mul, cast.getResult());
return success();
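In the scalar case, the net effect of MulOpRewritePattern is to route an integer multiply through the renamed primitive and cast the widened product back to the original type, roughly (a sketch assuming i8 operands widened to an i16 product, matching the regression test further down):

  // before
  %p = arith.muli %a, %b : i8
  // after
  %w = "hlscpp.prim.mul"(%a, %b) : (i8, i8) -> i16
  %p = "hlscpp.prim.cast"(%w) : (i16) -> i8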

View File

@@ -171,7 +171,7 @@ static bool applyLegalizeDataflow(Block &block, int64_t gran, bool balance) {
copyOp = builder.create<memref::CopyOp>(op->getLoc(), values.back(),
newValue);
} else {
-copyOp = builder.create<hlscpp::AssignOp>(
+copyOp = builder.create<hlscpp::BufferOp>(
op->getLoc(), value.getType(), values.back());
newValue = copyOp->getResult(0);
}

View File

@@ -88,18 +88,18 @@ bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc,
func.getResultTypes()));
}
-// Insert AssignOp when an arguments or result of ConstantOp are directly
+// Insert BufferOp when an arguments or result of ConstantOp are directly
// connected to ReturnOp.
auto returnOp = func.front().getTerminator();
builder.setInsertionPoint(returnOp);
unsigned idx = 0;
for (auto operand : returnOp->getOperands()) {
if (operand.dyn_cast<BlockArgument>()) {
-auto value = builder.create<AssignOp>(returnOp->getLoc(),
+auto value = builder.create<BufferOp>(returnOp->getLoc(),
operand.getType(), operand);
returnOp->setOperand(idx, value);
} else if (isa<arith::ConstantOp>(operand.getDefiningOp())) {
-auto value = builder.create<AssignOp>(returnOp->getLoc(),
+auto value = builder.create<BufferOp>(returnOp->getLoc(),
operand.getType(), operand);
returnOp->setOperand(idx, value);
}

View File

@@ -73,10 +73,10 @@
} // namespace
namespace {
-struct AssignOpRewritePattern : public OpRewritePattern<AssignOp> {
-using OpRewritePattern<AssignOp>::OpRewritePattern;
+struct BufferOpRewritePattern : public OpRewritePattern<BufferOp> {
+using OpRewritePattern<BufferOp>::OpRewritePattern;
-LogicalResult matchAndRewrite(AssignOp assign,
+LogicalResult matchAndRewrite(BufferOp assign,
PatternRewriter &rewriter) const override {
if (!assign->hasOneUse())
return failure();
@@ -233,7 +233,7 @@ struct ConvertCopyToAffineLoops
// Simplify alloc and copy ops.
mlir::RewritePatternSet patterns(context);
patterns.add<AllocOpRewritePattern>(context, DT);
-patterns.add<AssignOpRewritePattern>(context);
+patterns.add<BufferOpRewritePattern>(context);
patterns.add<ReshapeOpLoweringPattern>(context);
(void)applyPatternsAndFoldGreedily(module, std::move(patterns));

View File

@@ -225,7 +225,7 @@
void emitMemrefToTensor(bufferization::ToTensorOp op);
/// HLSCpp primitive operation emitters.
-void emitMulPrim(MulPrimOp op);
+void emitPrimMul(PrimMulOp op);
template <typename AssignOpType> void emitAssign(AssignOpType op);
/// Control flow operation emitters.
@@ -423,9 +423,9 @@
return emitter.emitMemrefToTensor(op), true;
}
/// HLSCpp primitive operations.
-bool visitOp(MulPrimOp op) { return emitter.emitMulPrim(op), true; }
-bool visitOp(CastPrimOp op) { return emitter.emitAssign(op), true; }
-bool visitOp(AssignOp op) { return emitter.emitAssign(op), true; }
+bool visitOp(PrimMulOp op) { return emitter.emitPrimMul(op), true; }
+bool visitOp(PrimCastOp op) { return emitter.emitAssign(op), true; }
+bool visitOp(BufferOp op) { return emitter.emitAssign(op), true; }
/// Control flow operations.
bool visitOp(func::CallOp op) { return emitter.emitCall(op), true; }
@@ -1268,7 +1268,7 @@ void ModuleEmitter::emitMemrefToTensor(bufferization::ToTensorOp op) {
}
/// HLSCpp primitive operation emitters.
-void ModuleEmitter::emitMulPrim(MulPrimOp op) {
+void ModuleEmitter::emitPrimMul(PrimMulOp op) {
if (op.isPackMul()) {
// Declare the result C array.
if (!isDeclared(op.C())) {
@@ -1825,7 +1825,7 @@
)XXX";
// Emit the multiplication primitive if required.
-if (module.walk([](MulPrimOp op) {
+if (module.walk([](PrimMulOp op) {
return op.isPackMul() ? WalkResult::interrupt() : WalkResult::advance();
}) == WalkResult::interrupt())
os << R"XXX(

View File

@@ -21,14 +21,14 @@
%2 = vector.transfer_read %arg1[%arg7, %arg8, %arg9, %arg6], %c0_i8 : memref<3x3x64x64xi8, #map1>, vector<2xi8>
%3 = vector.transfer_read %arg2[%arg3, %arg4, %arg5, %arg6], %c0_i8 : memref<1x32x32x64xi8, #map2>, vector<2xi8>
// CHECK: %3 = "hlscpp.mul_prim"(%0, %1) : (i8, vector<2xi8>) -> vector<2xi16>
// CHECK: %4 = "hlscpp.cast_prim"(%3) : (vector<2xi16>) -> vector<2xi8>
// CHECK: %3 = "hlscpp.prim.mul"(%0, %1) : (i8, vector<2xi8>) -> vector<2xi16>
// CHECK: %4 = "hlscpp.prim.cast"(%3) : (vector<2xi16>) -> vector<2xi8>
%4 = arith.muli %1, %2 : vector<2xi8>
// CHECK: %5 = "hlscpp.cast_prim"(%2) : (vector<2xi8>) -> vector<2xi32>
// CHECK: %6 = "hlscpp.cast_prim"(%4) : (vector<2xi8>) -> vector<2xi32>
// CHECK: %5 = "hlscpp.prim.cast"(%2) : (vector<2xi8>) -> vector<2xi32>
// CHECK: %6 = "hlscpp.prim.cast"(%4) : (vector<2xi8>) -> vector<2xi32>
// CHECK: %7 = arith.addi %5, %6 : vector<2xi32>
// CHECK: %8 = "hlscpp.cast_prim"(%7) : (vector<2xi32>) -> vector<2xi8>
// CHECK: %8 = "hlscpp.prim.cast"(%7) : (vector<2xi32>) -> vector<2xi8>
%5 = arith.addi %3, %4 : vector<2xi8>
vector.transfer_write %5, %arg2[%arg3, %arg4, %arg5, %arg6] : vector<2xi8>, memref<1x32x32x64xi8, #map2>
%6 = affine.apply #map3(%arg4)

View File

@@ -12,8 +12,8 @@
// CHECK: %2 = "tosa.clamp"
// CHECK: %3 = "tosa.conv2d"
// CHECK: %4 = "tosa.clamp"
// CHECK: %5 = "hlscpp.assign"
// CHECK: %6 = "hlscpp.assign"
// CHECK: %5 = "hlscpp.buffer"
// CHECK: %6 = "hlscpp.buffer"
// CHECK: return %4, %6
// CHECK: }
@@ -26,7 +26,7 @@
// CHECK: func @dataflow3(%arg0: tensor<1x32x32x64xi8>, %arg1: tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> {
// CHECK: %2 = "tosa.conv2d"
// CHECK: %3 = "hlscpp.assign"
// CHECK: %3 = "hlscpp.buffer"
// CHECK: %4 = "tosa.add"
// CHECK: %5 = "tosa.clamp"
// CHECK: return %5

View File

@@ -28,11 +28,11 @@
}
// CHECK-NOT: %[[VAL2:.*]] = bufferization.to_tensor %[[VAL0:.*]] : memref<1x32x32x64xi8>
// CHECK-NOT: %[[VAL3:.*]] = "hlscpp.assign"(%[[VAL2:.*]]) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
// CHECK-NOT: %[[VAL3:.*]] = "hlscpp.buffer"(%[[VAL2:.*]]) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
// CHECK-NOT: %[[VAL4:.*]] = bufferization.to_memref %[[VAL3:.*]] : memref<1x32x32x64xi8>
// CHECK-NOT: memref.copy %[[VAL4:.*]], %arg1 : memref<1x32x32x64xi8> to memref<1x32x32x64xi8>
%2 = bufferization.to_tensor %1 : memref<1x32x32x64xi8>
%10 = "hlscpp.assign"(%2) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
%10 = "hlscpp.buffer"(%2) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
%11 = bufferization.to_memref %10 : memref<1x32x32x64xi8>
memref.copy %11, %arg1 : memref<1x32x32x64xi8> to memref<1x32x32x64xi8>
return

View File

@@ -18,7 +18,7 @@
%0 = "tosa.clamp"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
%1 = "tosa.conv2d"(%0, %cst, %cst_0) {dilation = [1, 1], pad = [1, 1, 1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}, stride = [1, 1]} : (tensor<1x32x32x64xi8>, tensor<64x3x3x64xi8>, tensor<64xi8>) -> tensor<1x32x32x64xi8>
%2 = "tosa.clamp"(%1) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
%3 = "hlscpp.assign"(%0) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
%3 = "hlscpp.buffer"(%0) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
return %2, %3 : tensor<1x32x32x64xi8>, tensor<1x32x32x64xi8>
}
@@ -40,7 +40,7 @@
%cst = arith.constant dense<2> : tensor<64x3x3x64xi8>
%cst_0 = arith.constant dense<5> : tensor<64xi8>
%0 = "tosa.conv2d"(%arg0, %cst, %cst_0) {dilation = [1, 1], pad = [1, 1, 1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}, stride = [1, 1]} : (tensor<1x32x32x64xi8>, tensor<64x3x3x64xi8>, tensor<64xi8>) -> tensor<1x32x32x64xi8>
%1 = "hlscpp.assign"(%arg1) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
%1 = "hlscpp.buffer"(%arg1) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
%2 = "tosa.add"(%0, %1) : (tensor<1x32x32x64xi8>, tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
%3 = "tosa.clamp"(%2) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8>
return %3 : tensor<1x32x32x64xi8>

View File

@@ -46,16 +46,16 @@
// -----
-// CHECK-LABEL: func @test_assign(
+// CHECK-LABEL: func @test_buffer(
// CHECK-SAME: %arg0: f32, %arg1: memref<16xf32, 1>) -> (f32, memref<16xf32, 1>, i32, memref<2x2xi32, 1>)
-func @test_assign(%arg0: f32, %arg1: memref<16xf32, 1>) -> (f32, memref<16xf32, 1>, i32, memref<2x2xi32, 1>) {
+func @test_buffer(%arg0: f32, %arg1: memref<16xf32, 1>) -> (f32, memref<16xf32, 1>, i32, memref<2x2xi32, 1>) {
%c11_i32 = arith.constant 11 : i32
%cst = arith.constant dense<[[11, 0], [0, -42]]> : tensor<2x2xi32>
%cst_memref = bufferization.to_memref %cst : memref<2x2xi32, 1>
// CHECK: %1 = "hlscpp.assign"(%arg0) : (f32) -> f32
// CHECK: %2 = "hlscpp.assign"(%arg1) : (memref<16xf32, 1>) -> memref<16xf32, 1>
// CHECK: %3 = "hlscpp.assign"(%c11_i32) : (i32) -> i32
// CHECK: %1 = "hlscpp.buffer"(%arg0) : (f32) -> f32
// CHECK: %2 = "hlscpp.buffer"(%arg1) : (memref<16xf32, 1>) -> memref<16xf32, 1>
// CHECK: %3 = "hlscpp.buffer"(%c11_i32) : (i32) -> i32
// CHECK: return %1, %2, %3, %0 : f32, memref<16xf32, 1>, i32, memref<2x2xi32, 1>
return %arg0, %arg1, %c11_i32, %cst_memref : f32, memref<16xf32, 1>, i32, memref<2x2xi32, 1>
}