From 93ebdb54126ca21826794c6eda2c5d5a1fcffd25 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Sat, 19 Mar 2022 16:09:40 -0500 Subject: [PATCH] [HLSCpp] Add StreamChannel/Read/Write/BufferOp; Rename Mul/CastPrimOp to PrimMul/CastOp; Rename AssignOp to BufferOp --- include/scalehls/Dialect/HLSCpp/Ops.td | 61 +++++++++++++++---- include/scalehls/Dialect/HLSCpp/Types.td | 7 ++- include/scalehls/Dialect/HLSCpp/Visitor.h | 8 +-- lib/Dialect/HLSCpp/HLSCpp.cpp | 42 ++++++------- .../Directive/CreateHLSCppPrimitive.cpp | 10 +-- lib/Transforms/Graph/FuncDataflow.cpp | 2 +- lib/Transforms/LegalizeToHLSCpp.cpp | 6 +- .../Loop/ConvertCopyToAffineLoops.cpp | 8 +-- lib/Translation/EmitHLSCpp.cpp | 12 ++-- .../Directive/create_hlscpp_primitive.mlir | 10 +-- test/Transforms/Graph/func_dataflow.mlir | 6 +- .../Loop/convert_copy_to_affine_loops.mlir | 4 +- test/Transforms/create_runtime_main.mlir | 4 +- test/Transforms/legalize_to_hlscpp.mlir | 10 +-- 14 files changed, 114 insertions(+), 76 deletions(-) diff --git a/include/scalehls/Dialect/HLSCpp/Ops.td b/include/scalehls/Dialect/HLSCpp/Ops.td index cf92b69..8048854 100644 --- a/include/scalehls/Dialect/HLSCpp/Ops.td +++ b/include/scalehls/Dialect/HLSCpp/Ops.td @@ -9,11 +9,49 @@ include "mlir/Interfaces/SideEffectInterfaces.td" +//===----------------------------------------------------------------------===// +// Stream Operations +//===----------------------------------------------------------------------===// + +class StreamOf allowedTypes> : + Type()">, + Concat<"[](::mlir::Type elementType) { return ", + SubstLeaves<"$_self", "elementType", + AnyTypeOf.predicate>, + "; }($_self.cast<::mlir::scalehls::hlscpp::StreamType>()" + ".getElementType())">]>>; + +def StreamChannelOp : HLSCppOp<"stream.channel"> { + let summary = "Stream channel declaration operation"; + + let results = (outs StreamOf<[AnyType]>:$channel); +} + +def StreamReadOp : HLSCppOp<"stream.read"> { + let summary = "Stream channel read operation"; + + let arguments = (ins StreamOf<[AnyType]>:$channel); + let results = (outs Optional:$result); +} + +def StreamWriteOp : HLSCppOp<"stream.write"> { + let summary = "Stream channel write operation"; + + let arguments = (ins StreamOf<[AnyType]>:$channel, Optional:$value); +} + +def StreamBufferOp : HLSCppOp<"stream.buffer"> { + let summary = "Stream channel buffer operation"; + + let arguments = (ins StreamOf<[AnyType]>:$input); + let results = (outs StreamOf<[AnyType]>:$output); +} + //===----------------------------------------------------------------------===// // Primitive Operations //===----------------------------------------------------------------------===// -def MulPrimOp : HLSCppOp<"mul_prim", [NoSideEffect]> { +def PrimMulOp : HLSCppOp<"prim.mul", [NoSideEffect]> { let summary = "Multiplication primitive operation"; let description = [{ This primitive performs C = A * B, where A and B are 8-bits integers, while @@ -39,31 +77,28 @@ def MulPrimOp : HLSCppOp<"mul_prim", [NoSideEffect]> { let extraClassDeclaration = [{ bool isPackMul(); }]; } -def CastPrimOp : HLSCppOp<"cast_prim", +def PrimCastOp : HLSCppOp<"prim.cast", [SameOperandsAndResultShape, NoSideEffect]> { let summary = "Cast primitive operation"; let hasCanonicalizer = 1; let arguments = (ins - AnyTypeOf<[I8, I16, I32, VectorOfLengthAndType<[2], [I8, I16, I32]>]>:$in + AnyTypeOf<[I8, I16, I32, + VectorOfLengthAndType<[2], [I8, I16, I32]>]>:$input ); let results = (outs - AnyTypeOf<[I8, I16, I32, VectorOfLengthAndType<[2], [I8, I16, I32]>]>:$out + AnyTypeOf<[I8, I16, I32, + VectorOfLengthAndType<[2], [I8, I16, I32]>]>:$output ); } -def AssignOp : HLSCppOp<"assign", +def BufferOp : HLSCppOp<"buffer", [SameOperandsAndResultElementType, NoSideEffect]> { - let summary = "Assign the input value to the output"; - let description = [{ - This hlscpp.assign operation assigns the input value to the output, and can - be inserted anywhere without changing the original semantics. This is useful - for EmitHLSCpp to handle some corner cases and for tensor copy. - }]; + let summary = "Buffer the input value"; let hasCanonicalizer = 1; - let arguments = (ins AnyType : $input); - let results = (outs AnyType : $output); + let arguments = (ins AnyType:$input); + let results = (outs AnyType:$output); } #endif // SCALEHLS_DIALECT_HLSCPP_OPS_TD diff --git a/include/scalehls/Dialect/HLSCpp/Types.td b/include/scalehls/Dialect/HLSCpp/Types.td index bf741d8..4c9769f 100644 --- a/include/scalehls/Dialect/HLSCpp/Types.td +++ b/include/scalehls/Dialect/HLSCpp/Types.td @@ -15,9 +15,12 @@ def StreamType : HLSCppType<"Stream"> { }]; let mnemonic = "stream"; - let parameters = (ins "::mlir::Type":$elementType); + let parameters = (ins + "::mlir::Type":$elementType, + "int64_t":$depth + ); - let assemblyFormat = "`<` qualified($elementType) `>`"; + let assemblyFormat = "`<` qualified($elementType) `,` $depth `>`"; } #endif // SCALEHLS_DIALECT_HLSCPP_TYPES_TD diff --git a/include/scalehls/Dialect/HLSCpp/Visitor.h b/include/scalehls/Dialect/HLSCpp/Visitor.h index 04e0b03..4d95068 100644 --- a/include/scalehls/Dialect/HLSCpp/Visitor.h +++ b/include/scalehls/Dialect/HLSCpp/Visitor.h @@ -53,7 +53,7 @@ public: bufferization::ToMemrefOp, bufferization::ToTensorOp, // HLSCpp primitive operations. - MulPrimOp, CastPrimOp, AssignOp, + PrimMulOp, PrimCastOp, BufferOp, // Control flow operations. func::CallOp, func::ReturnOp, @@ -146,9 +146,9 @@ public: HANDLE(bufferization::ToTensorOp); // HLSCpp primitive operations. - HANDLE(MulPrimOp); - HANDLE(CastPrimOp); - HANDLE(AssignOp); + HANDLE(PrimMulOp); + HANDLE(PrimCastOp); + HANDLE(BufferOp); // Control flow operations. HANDLE(func::CallOp); diff --git a/lib/Dialect/HLSCpp/HLSCpp.cpp b/lib/Dialect/HLSCpp/HLSCpp.cpp index a89f8f6..291637e 100644 --- a/lib/Dialect/HLSCpp/HLSCpp.cpp +++ b/lib/Dialect/HLSCpp/HLSCpp.cpp @@ -35,10 +35,10 @@ void HLSCppDialect::initialize() { } //===----------------------------------------------------------------------===// -// MulPrimOp +// PrimMulOp //===----------------------------------------------------------------------===// -static LogicalResult verify(MulPrimOp op) { +static LogicalResult verify(PrimMulOp op) { auto AIsVector = op.A().getType().isa(); auto BIsVector = op.B().getType().isa(); auto CIsVector = op.C().getType().isa(); @@ -50,7 +50,7 @@ static LogicalResult verify(MulPrimOp op) { return failure(); } -bool MulPrimOp::isPackMul() { +bool PrimMulOp::isPackMul() { auto AIsVector = A().getType().isa(); auto BIsVector = B().getType().isa(); return (AIsVector && !BIsVector) || (!AIsVector && BIsVector); @@ -194,22 +194,22 @@ void FuncDirectiveAttr::print(AsmPrinter &p) const { //===----------------------------------------------------------------------===// namespace { -struct SimplifyCastPrimOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct SimplifyPrimCastOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(CastPrimOp cast, + LogicalResult matchAndRewrite(PrimCastOp cast, PatternRewriter &rewriter) const override { - if (cast.in().getType() == cast.out().getType()) { - rewriter.replaceOp(cast, cast.in()); + if (cast.input().getType() == cast.output().getType()) { + rewriter.replaceOp(cast, cast.input()); return success(); } // If the input of the cast is defined by another cast, then the two casts // can be merged into one. - if (cast.in().hasOneUse()) - if (auto defCast = cast.in().getDefiningOp()) { - rewriter.replaceOpWithNewOp(cast, cast.getType(), - defCast.in()); + if (cast.input().hasOneUse()) + if (auto defCast = cast.input().getDefiningOp()) { + rewriter.replaceOpWithNewOp(cast, cast.getType(), + defCast.input()); return success(); } @@ -218,19 +218,19 @@ struct SimplifyCastPrimOp : public OpRewritePattern { }; } // namespace -void CastPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, +void PrimCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } namespace { -struct SimplifyAssignOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct SimplifyBufferOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AssignOp assign, + LogicalResult matchAndRewrite(BufferOp buffer, PatternRewriter &rewriter) const override { - if (auto defOp = assign.input().getDefiningOp()) { - assign.inputMutable().assign(defOp.input()); + if (auto defOp = buffer.input().getDefiningOp()) { + buffer.inputMutable().assign(defOp.input()); return success(); } return failure(); @@ -238,9 +238,9 @@ struct SimplifyAssignOp : public OpRewritePattern { }; } // namespace -void AssignOp::getCanonicalizationPatterns(RewritePatternSet &results, +void BufferOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// // Include tablegen classes diff --git a/lib/Transforms/Directive/CreateHLSCppPrimitive.cpp b/lib/Transforms/Directive/CreateHLSCppPrimitive.cpp index 2525533..2a296ea 100644 --- a/lib/Transforms/Directive/CreateHLSCppPrimitive.cpp +++ b/lib/Transforms/Directive/CreateHLSCppPrimitive.cpp @@ -39,15 +39,15 @@ struct AddOpRewritePattern : public OpRewritePattern { // Cast add op operand from the new type. auto loc = add.getLoc(); rewriter.setInsertionPoint(add); - auto newLhs = rewriter.create(loc, newType, add.getLhs()); - auto newRhs = rewriter.create(loc, newType, add.getRhs()); + auto newLhs = rewriter.create(loc, newType, add.getLhs()); + auto newRhs = rewriter.create(loc, newType, add.getRhs()); add.getLhsMutable().assign(newLhs); add.getRhsMutable().assign(newRhs); // Cast add op result to the new type. rewriter.setInsertionPointAfter(add); auto cast = - rewriter.create(loc, add.getType(), add.getResult()); + rewriter.create(loc, add.getType(), add.getResult()); add.getResult().replaceAllUsesExcept(cast.getResult(), cast); add.getResult().setType(newType); @@ -87,8 +87,8 @@ struct MulOpRewritePattern : public OpRewritePattern { // Replace the original op with multiplication primitive op. auto loc = mul.getLoc(); rewriter.setInsertionPoint(mul); - auto mulResult = rewriter.create(loc, newType, lhs, rhs); - auto cast = rewriter.create(loc, mul.getType(), mulResult); + auto mulResult = rewriter.create(loc, newType, lhs, rhs); + auto cast = rewriter.create(loc, mul.getType(), mulResult); rewriter.replaceOp(mul, cast.getResult()); return success(); diff --git a/lib/Transforms/Graph/FuncDataflow.cpp b/lib/Transforms/Graph/FuncDataflow.cpp index 8992830..cf4ebec 100644 --- a/lib/Transforms/Graph/FuncDataflow.cpp +++ b/lib/Transforms/Graph/FuncDataflow.cpp @@ -171,7 +171,7 @@ static bool applyLegalizeDataflow(Block &block, int64_t gran, bool balance) { copyOp = builder.create(op->getLoc(), values.back(), newValue); } else { - copyOp = builder.create( + copyOp = builder.create( op->getLoc(), value.getType(), values.back()); newValue = copyOp->getResult(0); } diff --git a/lib/Transforms/LegalizeToHLSCpp.cpp b/lib/Transforms/LegalizeToHLSCpp.cpp index 6892847..71c1c77 100644 --- a/lib/Transforms/LegalizeToHLSCpp.cpp +++ b/lib/Transforms/LegalizeToHLSCpp.cpp @@ -88,18 +88,18 @@ bool scalehls::applyLegalizeToHLSCpp(FuncOp func, bool isTopFunc, func.getResultTypes())); } - // Insert AssignOp when an arguments or result of ConstantOp are directly + // Insert BufferOp when an arguments or result of ConstantOp are directly // connected to ReturnOp. auto returnOp = func.front().getTerminator(); builder.setInsertionPoint(returnOp); unsigned idx = 0; for (auto operand : returnOp->getOperands()) { if (operand.dyn_cast()) { - auto value = builder.create(returnOp->getLoc(), + auto value = builder.create(returnOp->getLoc(), operand.getType(), operand); returnOp->setOperand(idx, value); } else if (isa(operand.getDefiningOp())) { - auto value = builder.create(returnOp->getLoc(), + auto value = builder.create(returnOp->getLoc(), operand.getType(), operand); returnOp->setOperand(idx, value); } diff --git a/lib/Transforms/Loop/ConvertCopyToAffineLoops.cpp b/lib/Transforms/Loop/ConvertCopyToAffineLoops.cpp index 1a429f0..9ee5bd5 100644 --- a/lib/Transforms/Loop/ConvertCopyToAffineLoops.cpp +++ b/lib/Transforms/Loop/ConvertCopyToAffineLoops.cpp @@ -73,10 +73,10 @@ private: } // namespace namespace { -struct AssignOpRewritePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct BufferOpRewritePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AssignOp assign, + LogicalResult matchAndRewrite(BufferOp assign, PatternRewriter &rewriter) const override { if (!assign->hasOneUse()) return failure(); @@ -233,7 +233,7 @@ struct ConvertCopyToAffineLoops // Simplify alloc and copy ops. mlir::RewritePatternSet patterns(context); patterns.add(context, DT); - patterns.add(context); + patterns.add(context); patterns.add(context); (void)applyPatternsAndFoldGreedily(module, std::move(patterns)); diff --git a/lib/Translation/EmitHLSCpp.cpp b/lib/Translation/EmitHLSCpp.cpp index dd091db..8fb9914 100644 --- a/lib/Translation/EmitHLSCpp.cpp +++ b/lib/Translation/EmitHLSCpp.cpp @@ -225,7 +225,7 @@ public: void emitMemrefToTensor(bufferization::ToTensorOp op); /// HLSCpp primitive operation emitters. - void emitMulPrim(MulPrimOp op); + void emitPrimMul(PrimMulOp op); template void emitAssign(AssignOpType op); /// Control flow operation emitters. @@ -423,9 +423,9 @@ public: return emitter.emitMemrefToTensor(op), true; } /// HLSCpp primitive operations. - bool visitOp(MulPrimOp op) { return emitter.emitMulPrim(op), true; } - bool visitOp(CastPrimOp op) { return emitter.emitAssign(op), true; } - bool visitOp(AssignOp op) { return emitter.emitAssign(op), true; } + bool visitOp(PrimMulOp op) { return emitter.emitPrimMul(op), true; } + bool visitOp(PrimCastOp op) { return emitter.emitAssign(op), true; } + bool visitOp(BufferOp op) { return emitter.emitAssign(op), true; } /// Control flow operations. bool visitOp(func::CallOp op) { return emitter.emitCall(op), true; } @@ -1268,7 +1268,7 @@ void ModuleEmitter::emitMemrefToTensor(bufferization::ToTensorOp op) { } /// HLSCpp primitive operation emitters. -void ModuleEmitter::emitMulPrim(MulPrimOp op) { +void ModuleEmitter::emitPrimMul(PrimMulOp op) { if (op.isPackMul()) { // Declare the result C array. if (!isDeclared(op.C())) { @@ -1825,7 +1825,7 @@ using namespace std; )XXX"; // Emit the multiplication primitive if required. - if (module.walk([](MulPrimOp op) { + if (module.walk([](PrimMulOp op) { return op.isPackMul() ? WalkResult::interrupt() : WalkResult::advance(); }) == WalkResult::interrupt()) os << R"XXX( diff --git a/test/Transforms/Directive/create_hlscpp_primitive.mlir b/test/Transforms/Directive/create_hlscpp_primitive.mlir index c913ecf..e758182 100644 --- a/test/Transforms/Directive/create_hlscpp_primitive.mlir +++ b/test/Transforms/Directive/create_hlscpp_primitive.mlir @@ -21,14 +21,14 @@ module { %2 = vector.transfer_read %arg1[%arg7, %arg8, %arg9, %arg6], %c0_i8 : memref<3x3x64x64xi8, #map1>, vector<2xi8> %3 = vector.transfer_read %arg2[%arg3, %arg4, %arg5, %arg6], %c0_i8 : memref<1x32x32x64xi8, #map2>, vector<2xi8> - // CHECK: %3 = "hlscpp.mul_prim"(%0, %1) : (i8, vector<2xi8>) -> vector<2xi16> - // CHECK: %4 = "hlscpp.cast_prim"(%3) : (vector<2xi16>) -> vector<2xi8> + // CHECK: %3 = "hlscpp.prim.mul"(%0, %1) : (i8, vector<2xi8>) -> vector<2xi16> + // CHECK: %4 = "hlscpp.prim.cast"(%3) : (vector<2xi16>) -> vector<2xi8> %4 = arith.muli %1, %2 : vector<2xi8> - // CHECK: %5 = "hlscpp.cast_prim"(%2) : (vector<2xi8>) -> vector<2xi32> - // CHECK: %6 = "hlscpp.cast_prim"(%4) : (vector<2xi8>) -> vector<2xi32> + // CHECK: %5 = "hlscpp.prim.cast"(%2) : (vector<2xi8>) -> vector<2xi32> + // CHECK: %6 = "hlscpp.prim.cast"(%4) : (vector<2xi8>) -> vector<2xi32> // CHECK: %7 = arith.addi %5, %6 : vector<2xi32> - // CHECK: %8 = "hlscpp.cast_prim"(%7) : (vector<2xi32>) -> vector<2xi8> + // CHECK: %8 = "hlscpp.prim.cast"(%7) : (vector<2xi32>) -> vector<2xi8> %5 = arith.addi %3, %4 : vector<2xi8> vector.transfer_write %5, %arg2[%arg3, %arg4, %arg5, %arg6] : vector<2xi8>, memref<1x32x32x64xi8, #map2> %6 = affine.apply #map3(%arg4) diff --git a/test/Transforms/Graph/func_dataflow.mlir b/test/Transforms/Graph/func_dataflow.mlir index c62ca35..2cd2bc9 100644 --- a/test/Transforms/Graph/func_dataflow.mlir +++ b/test/Transforms/Graph/func_dataflow.mlir @@ -12,8 +12,8 @@ module { // CHECK: %2 = "tosa.clamp" // CHECK: %3 = "tosa.conv2d" // CHECK: %4 = "tosa.clamp" - // CHECK: %5 = "hlscpp.assign" - // CHECK: %6 = "hlscpp.assign" + // CHECK: %5 = "hlscpp.buffer" + // CHECK: %6 = "hlscpp.buffer" // CHECK: return %4, %6 // CHECK: } @@ -26,7 +26,7 @@ module { // CHECK: func @dataflow3(%arg0: tensor<1x32x32x64xi8>, %arg1: tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> { // CHECK: %2 = "tosa.conv2d" - // CHECK: %3 = "hlscpp.assign" + // CHECK: %3 = "hlscpp.buffer" // CHECK: %4 = "tosa.add" // CHECK: %5 = "tosa.clamp" // CHECK: return %5 diff --git a/test/Transforms/Loop/convert_copy_to_affine_loops.mlir b/test/Transforms/Loop/convert_copy_to_affine_loops.mlir index bbcd405..e9e97da 100644 --- a/test/Transforms/Loop/convert_copy_to_affine_loops.mlir +++ b/test/Transforms/Loop/convert_copy_to_affine_loops.mlir @@ -28,11 +28,11 @@ module { } // CHECK-NOT: %[[VAL2:.*]] = bufferization.to_tensor %[[VAL0:.*]] : memref<1x32x32x64xi8> - // CHECK-NOT: %[[VAL3:.*]] = "hlscpp.assign"(%[[VAL2:.*]]) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> + // CHECK-NOT: %[[VAL3:.*]] = "hlscpp.buffer"(%[[VAL2:.*]]) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> // CHECK-NOT: %[[VAL4:.*]] = bufferization.to_memref %[[VAL3:.*]] : memref<1x32x32x64xi8> // CHECK-NOT: memref.copy %[[VAL4:.*]], %arg1 : memref<1x32x32x64xi8> to memref<1x32x32x64xi8> %2 = bufferization.to_tensor %1 : memref<1x32x32x64xi8> - %10 = "hlscpp.assign"(%2) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> + %10 = "hlscpp.buffer"(%2) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> %11 = bufferization.to_memref %10 : memref<1x32x32x64xi8> memref.copy %11, %arg1 : memref<1x32x32x64xi8> to memref<1x32x32x64xi8> return diff --git a/test/Transforms/create_runtime_main.mlir b/test/Transforms/create_runtime_main.mlir index fe51169..e8ac0a4 100644 --- a/test/Transforms/create_runtime_main.mlir +++ b/test/Transforms/create_runtime_main.mlir @@ -18,7 +18,7 @@ module { %0 = "tosa.clamp"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> %1 = "tosa.conv2d"(%0, %cst, %cst_0) {dilation = [1, 1], pad = [1, 1, 1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}, stride = [1, 1]} : (tensor<1x32x32x64xi8>, tensor<64x3x3x64xi8>, tensor<64xi8>) -> tensor<1x32x32x64xi8> %2 = "tosa.clamp"(%1) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> - %3 = "hlscpp.assign"(%0) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> + %3 = "hlscpp.buffer"(%0) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> return %2, %3 : tensor<1x32x32x64xi8>, tensor<1x32x32x64xi8> } @@ -40,7 +40,7 @@ module { %cst = arith.constant dense<2> : tensor<64x3x3x64xi8> %cst_0 = arith.constant dense<5> : tensor<64xi8> %0 = "tosa.conv2d"(%arg0, %cst, %cst_0) {dilation = [1, 1], pad = [1, 1, 1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}, stride = [1, 1]} : (tensor<1x32x32x64xi8>, tensor<64x3x3x64xi8>, tensor<64xi8>) -> tensor<1x32x32x64xi8> - %1 = "hlscpp.assign"(%arg1) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> + %1 = "hlscpp.buffer"(%arg1) : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> %2 = "tosa.add"(%0, %1) : (tensor<1x32x32x64xi8>, tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> %3 = "tosa.clamp"(%2) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x32x32x64xi8>) -> tensor<1x32x32x64xi8> return %3 : tensor<1x32x32x64xi8> diff --git a/test/Transforms/legalize_to_hlscpp.mlir b/test/Transforms/legalize_to_hlscpp.mlir index bdccaa8..297bf11 100644 --- a/test/Transforms/legalize_to_hlscpp.mlir +++ b/test/Transforms/legalize_to_hlscpp.mlir @@ -46,16 +46,16 @@ module { // ----- -// CHECK-LABEL: func @test_assign( +// CHECK-LABEL: func @test_buffer( // CHECK-SAME: %arg0: f32, %arg1: memref<16xf32, 1>) -> (f32, memref<16xf32, 1>, i32, memref<2x2xi32, 1>) -func @test_assign(%arg0: f32, %arg1: memref<16xf32, 1>) -> (f32, memref<16xf32, 1>, i32, memref<2x2xi32, 1>) { +func @test_buffer(%arg0: f32, %arg1: memref<16xf32, 1>) -> (f32, memref<16xf32, 1>, i32, memref<2x2xi32, 1>) { %c11_i32 = arith.constant 11 : i32 %cst = arith.constant dense<[[11, 0], [0, -42]]> : tensor<2x2xi32> %cst_memref = bufferization.to_memref %cst : memref<2x2xi32, 1> - // CHECK: %1 = "hlscpp.assign"(%arg0) : (f32) -> f32 - // CHECK: %2 = "hlscpp.assign"(%arg1) : (memref<16xf32, 1>) -> memref<16xf32, 1> - // CHECK: %3 = "hlscpp.assign"(%c11_i32) : (i32) -> i32 + // CHECK: %1 = "hlscpp.buffer"(%arg0) : (f32) -> f32 + // CHECK: %2 = "hlscpp.buffer"(%arg1) : (memref<16xf32, 1>) -> memref<16xf32, 1> + // CHECK: %3 = "hlscpp.buffer"(%c11_i32) : (i32) -> i32 // CHECK: return %1, %2, %3, %0 : f32, memref<16xf32, 1>, i32, memref<2x2xi32, 1> return %arg0, %arg1, %c11_i32, %cst_memref : f32, memref<16xf32, 1>, i32, memref<2x2xi32, 1> }