diff --git a/include/scalehls/Dialect/HLSCpp/Ops.td b/include/scalehls/Dialect/HLSCpp/Ops.td index 8048854..a35350b 100644 --- a/include/scalehls/Dialect/HLSCpp/Ops.td +++ b/include/scalehls/Dialect/HLSCpp/Ops.td @@ -37,7 +37,7 @@ def StreamReadOp : HLSCppOp<"stream.read"> { def StreamWriteOp : HLSCppOp<"stream.write"> { let summary = "Stream channel write operation"; - let arguments = (ins StreamOf<[AnyType]>:$channel, Optional:$value); + let arguments = (ins StreamOf<[AnyType]>:$channel, AnyType:$value); } def StreamBufferOp : HLSCppOp<"stream.buffer"> { diff --git a/include/scalehls/Dialect/HLSCpp/Visitor.h b/include/scalehls/Dialect/HLSCpp/Visitor.h index 4d95068..5dc905c 100644 --- a/include/scalehls/Dialect/HLSCpp/Visitor.h +++ b/include/scalehls/Dialect/HLSCpp/Visitor.h @@ -52,8 +52,9 @@ public: memref::ExpandShapeOp, memref::ReinterpretCastOp, bufferization::ToMemrefOp, bufferization::ToTensorOp, - // HLSCpp primitive operations. - PrimMulOp, PrimCastOp, BufferOp, + // HLSCpp dialect operations. + StreamChannelOp, StreamReadOp, StreamWriteOp, PrimMulOp, PrimCastOp, + BufferOp, // Control flow operations. func::CallOp, func::ReturnOp, @@ -145,7 +146,10 @@ public: HANDLE(bufferization::ToMemrefOp); HANDLE(bufferization::ToTensorOp); - // HLSCpp primitive operations. + // HLSCpp dialect operations. + HANDLE(StreamChannelOp); + HANDLE(StreamReadOp); + HANDLE(StreamWriteOp); HANDLE(PrimMulOp); HANDLE(PrimCastOp); HANDLE(BufferOp); diff --git a/lib/Transforms/Graph/CreateTokenFlow.cpp b/lib/Transforms/Graph/CreateTokenFlow.cpp index a9407d8..1482a2b 100644 --- a/lib/Transforms/Graph/CreateTokenFlow.cpp +++ b/lib/Transforms/Graph/CreateTokenFlow.cpp @@ -51,7 +51,9 @@ struct CreateTokenFlow : public CreateTokenFlowBase { // Create a new token stream channel. builder.setInsertionPoint(returnOp); auto channel = builder.create(loc, tokenType); - builder.create(loc, channel, Value()); + auto tokenValue = + builder.create(loc, builder.getBoolAttr(false)); + builder.create(loc, channel, tokenValue); // Collect the users of the stream channel. for (auto user : result.getUsers()) { diff --git a/lib/Transforms/Graph/HoistStreamChannel.cpp b/lib/Transforms/Graph/HoistStreamChannel.cpp index 8db1887..554a3de 100644 --- a/lib/Transforms/Graph/HoistStreamChannel.cpp +++ b/lib/Transforms/Graph/HoistStreamChannel.cpp @@ -129,7 +129,10 @@ struct LowerStreamBufferOpRewritePattern buffer, buffer.getType()); rewriter.setInsertionPoint(block->getTerminator()); - rewriter.create(loc, channel, Value()); + auto tokenType = channel.getType().cast().getElementType(); + auto tokenValue = rewriter.create( + loc, rewriter.getZeroAttr(tokenType)); + rewriter.create(loc, channel, tokenValue); return success(); } }; diff --git a/lib/Translation/EmitHLSCpp.cpp b/lib/Translation/EmitHLSCpp.cpp index 8fb9914..278b1da 100644 --- a/lib/Translation/EmitHLSCpp.cpp +++ b/lib/Translation/EmitHLSCpp.cpp @@ -26,6 +26,8 @@ static SmallString<16> getTypeName(Value val) { auto valType = val.getType(); if (auto arrayType = val.getType().dyn_cast()) valType = arrayType.getElementType(); + else if (auto streamType = val.getType().dyn_cast()) + valType = streamType.getElementType(); // Handle float types. if (valType.isa()) @@ -224,7 +226,10 @@ public: void emitTensorToMemref(bufferization::ToMemrefOp op); void emitMemrefToTensor(bufferization::ToTensorOp op); - /// HLSCpp primitive operation emitters. + /// HLSCpp dialect operation emitters. + void emitStreamChannel(StreamChannelOp op); + void emitStreamRead(StreamReadOp op); + void emitStreamWrite(StreamWriteOp op); void emitPrimMul(PrimMulOp op); template void emitAssign(AssignOpType op); @@ -249,7 +254,8 @@ private: SmallVector, 4> getTransferIndices(TransferOpType op); /// C++ component emitters. - void emitValue(Value val, unsigned rank = 0, bool isPtr = false); + void emitValue(Value val, unsigned rank = 0, bool isPtr = false, + bool isRef = false); void emitArrayDecl(Value array); unsigned emitNestedLoopHeader(Value val); void emitNestedLoopFooter(unsigned rank); @@ -422,7 +428,13 @@ public: bool visitOp(bufferization::ToTensorOp op) { return emitter.emitMemrefToTensor(op), true; } - /// HLSCpp primitive operations. + + /// HLSCpp dialect operations. + bool visitOp(StreamChannelOp op) { + return emitter.emitStreamChannel(op), true; + } + bool visitOp(StreamReadOp op) { return emitter.emitStreamRead(op), true; } + bool visitOp(StreamWriteOp op) { return emitter.emitStreamWrite(op), true; } bool visitOp(PrimMulOp op) { return emitter.emitPrimMul(op), true; } bool visitOp(PrimCastOp op) { return emitter.emitAssign(op), true; } bool visitOp(BufferOp op) { return emitter.emitAssign(op), true; } @@ -572,8 +584,7 @@ bool ExprVisitor::visitOp(arith::CmpIOp op) { /// SCF statement emitters. void ModuleEmitter::emitScfFor(scf::ForOp op) { - indent(); - os << "for ("; + indent() << "for ("; auto iterVar = op.getInductionVar(); // Emit lower bound. @@ -601,8 +612,7 @@ void ModuleEmitter::emitScfFor(scf::ForOp op) { emitBlock(*op.getBody()); reduceIndent(); - indent(); - os << "}\n"; + indent() << "}\n"; } void ModuleEmitter::emitScfIf(scf::IfOp op) { @@ -619,8 +629,7 @@ void ModuleEmitter::emitScfIf(scf::IfOp op) { } } - indent(); - os << "if ("; + indent() << "if ("; emitValue(op.getCondition()); os << ") {"; emitInfoAndNewLine(op); @@ -630,15 +639,13 @@ void ModuleEmitter::emitScfIf(scf::IfOp op) { reduceIndent(); if (!op.getElseRegion().empty()) { - indent(); - os << "} else {\n"; + indent() << "} else {\n"; addIndent(); emitBlock(op.getElseRegion().front()); reduceIndent(); } - indent(); - os << "}\n"; + indent() << "}\n"; } void ModuleEmitter::emitScfYield(scf::YieldOp op) { @@ -664,8 +671,7 @@ void ModuleEmitter::emitScfYield(scf::YieldOp op) { /// Affine statement emitters. void ModuleEmitter::emitAffineFor(AffineForOp op) { - indent(); - os << "for ("; + indent() << "for ("; auto iterVar = op.getInductionVar(); // Emit lower bound. @@ -719,8 +725,7 @@ void ModuleEmitter::emitAffineFor(AffineForOp op) { emitBlock(*op.getBody()); reduceIndent(); - indent(); - os << "}\n"; + indent() << "}\n"; } void ModuleEmitter::emitAffineIf(AffineIfOp op) { @@ -737,8 +742,7 @@ void ModuleEmitter::emitAffineIf(AffineIfOp op) { } } - indent(); - os << "if ("; + indent() << "if ("; auto constrSet = op.getIntegerSet(); AffineExprEmitter constrEmitter(state, constrSet.getNumDims(), op.getOperands()); @@ -763,15 +767,13 @@ void ModuleEmitter::emitAffineIf(AffineIfOp op) { reduceIndent(); if (op.hasElse()) { - indent(); - os << "} else {\n"; + indent() << "} else {\n"; addIndent(); emitBlock(*op.getElseBlock()); reduceIndent(); } - indent(); - os << "}\n"; + indent() << "}\n"; } void ModuleEmitter::emitAffineParallel(AffineParallelOp op) { @@ -790,8 +792,7 @@ void ModuleEmitter::emitAffineParallel(AffineParallelOp op) { auto steps = getIntArrayAttrValue(op, op.getStepsAttrName()); for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) { - indent(); - os << "for ("; + indent() << "for ("; auto iterVar = op.getBody()->getArgument(i); // Emit lower bound. @@ -825,8 +826,7 @@ void ModuleEmitter::emitAffineParallel(AffineParallelOp op) { for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) { reduceIndent(); - indent(); - os << "}\n"; + indent() << "}\n"; } } @@ -918,8 +918,7 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) { emitNestedLoopFooter(rank); } } else if (auto parentOp = dyn_cast(op->getParentOp())) { - indent(); - os << "if ("; + indent() << "if ("; unsigned ivIdx = 0; for (auto iv : parentOp.getBody()->getArguments()) { emitValue(iv); @@ -945,8 +944,7 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) { } reduceIndent(); - indent(); - os << "} else {\n"; + indent() << "} else {\n"; // Otherwise, generated values will be accumulated/reduced to the // current results with corresponding arith::AtomicRMWKind operations. @@ -1006,8 +1004,7 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) { } reduceIndent(); - indent(); - os << "}\n"; + indent() << "}\n"; } } @@ -1212,8 +1209,7 @@ template void ModuleEmitter::emitReshape(OpType op) { assert(!isDeclared(array) && "has been declared before."); auto arrayType = array.getType().template cast(); - indent(); - os << getTypeName(array) << " (*"; + indent() << getTypeName(array) << " (*"; // Add the new value to nameTable and emit its name. os << addName(array, false); @@ -1267,7 +1263,35 @@ void ModuleEmitter::emitMemrefToTensor(bufferization::ToTensorOp op) { } } -/// HLSCpp primitive operation emitters. +/// HLSCpp dialect operation emitters. +void ModuleEmitter::emitStreamChannel(StreamChannelOp op) { + indent(); + emitValue(op.channel()); + os << ";"; + emitInfoAndNewLine(op); +} + +void ModuleEmitter::emitStreamRead(StreamReadOp op) { + indent(); + if (op.result()) { + emitValue(op.result()); + os << " = "; + } + emitValue(op.channel()); + os << ".read("; + os << ");"; + emitInfoAndNewLine(op); +} + +void ModuleEmitter::emitStreamWrite(StreamWriteOp op) { + indent(); + emitValue(op.channel()); + os << ".write("; + emitValue(op.value()); + os << ");"; + emitInfoAndNewLine(op); +} + void ModuleEmitter::emitPrimMul(PrimMulOp op) { if (op.isPackMul()) { // Declare the result C array. @@ -1334,8 +1358,7 @@ void ModuleEmitter::emitCall(func::CallOp op) { } // Emit the function call. - indent(); - os << op.getCallee() << "("; + indent() << op.getCallee() << "("; // Handle input arguments. unsigned argIdx = 0; @@ -1470,7 +1493,8 @@ void ModuleEmitter::emitConstant(arith::ConstantOp op) { } /// C++ component emitters. -void ModuleEmitter::emitValue(Value val, unsigned rank, bool isPtr) { +void ModuleEmitter::emitValue(Value val, unsigned rank, bool isPtr, + bool isRef) { assert(!(rank && isPtr) && "should be either an array or a pointer."); // Value has been declared before or is a constant number. @@ -1481,7 +1505,13 @@ void ModuleEmitter::emitValue(Value val, unsigned rank, bool isPtr) { return; } - os << getTypeName(val) << " "; + if (val.getType().isa()) + os << "hls::stream<" << getTypeName(val) << "> "; + else + os << getTypeName(val) << " "; + + if (isRef) + os << "&"; // Add the new value to nameTable and emit its name. os << addName(val, isPtr); @@ -1527,8 +1557,7 @@ unsigned ModuleEmitter::emitNestedLoopHeader(Value val) { // Create nested loop. unsigned dimIdx = 0; for (auto &shape : type.getShape()) { - indent(); - os << "for (int iv" << dimIdx << " = 0; "; + indent() << "for (int iv" << dimIdx << " = 0; "; os << "iv" << dimIdx << " < " << shape << "; "; os << "++iv" << dimIdx++ << ") {\n"; @@ -1548,8 +1577,7 @@ void ModuleEmitter::emitNestedLoopFooter(unsigned rank) { for (unsigned i = 0; i < rank; ++i) { reduceIndent(); - indent(); - os << "}\n"; + indent() << "}\n"; } } @@ -1589,14 +1617,10 @@ void ModuleEmitter::emitLoopDirectives(Operation *op) { if (!loopDirect) return; - if (loopDirect.getPipeline()) { - indent(); - os << "#pragma HLS pipeline II=" << loopDirect.getTargetII() << "\n"; - - } else if (loopDirect.getDataflow()) { - indent(); - os << "#pragma HLS dataflow\n"; - } + if (loopDirect.getPipeline()) + indent() << "#pragma HLS pipeline II=" << loopDirect.getTargetII() << "\n"; + else if (loopDirect.getDataflow()) + indent() << "#pragma HLS dataflow\n"; } void ModuleEmitter::emitArrayDirectives(Value memref) { @@ -1612,8 +1636,7 @@ void ModuleEmitter::emitArrayDirectives(Value memref) { if (factors[dim] != 1) { emitPragmaFlag = true; - indent(); - os << "#pragma HLS array_partition"; + indent() << "#pragma HLS array_partition"; os << " variable="; emitValue(memref); @@ -1634,8 +1657,7 @@ void ModuleEmitter::emitArrayDirectives(Value memref) { if (kind != MemoryKind::DRAM && !isFullyPartitioned(type)) { emitPragmaFlag = true; - indent(); - os << "#pragma HLS resource"; + indent() << "#pragma HLS resource"; os << " variable="; emitValue(memref); @@ -1651,6 +1673,16 @@ void ModuleEmitter::emitArrayDirectives(Value memref) { os << "\n"; } + // Emit DRAM variable as stable. + if (kind == MemoryKind::DRAM) { + emitPragmaFlag = true; + + indent() << "#pragma HLS stable"; + os << " variable="; + emitValue(memref); + os << "\n"; + } + // Emit an empty line. if (emitPragmaFlag) os << "\n"; @@ -1660,8 +1692,7 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func, ArrayRef portList) { // Only top function should emit interface pragmas. if (hasTopFuncAttr(func)) { - indent(); - os << "#pragma HLS interface s_axilite port=return bundle=ctrl\n"; + indent() << "#pragma HLS interface s_axilite port=return bundle=ctrl\n"; for (auto &port : portList) { // Array ports and scalar ports are handled separately. Here, we only @@ -1669,8 +1700,7 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func, if (auto memrefType = port.getType().dyn_cast()) { // Only emit interface pragma when the array is not fully partitioned. if (!isFullyPartitioned(memrefType)) { - indent(); - os << "#pragma HLS interface"; + indent() << "#pragma HLS interface"; // For now, we set the offset of all m_axi interfaces as slave. if (MemoryKind(memrefType.getMemorySpaceAsInt()) == MemoryKind::DRAM) { @@ -1684,8 +1714,7 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func, os << "\n"; } } else { - indent(); - os << "#pragma HLS interface s_axilite"; + indent() << "#pragma HLS interface s_axilite"; os << " port="; // TODO: This is a temporary solution. @@ -1711,14 +1740,13 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func, return; if (funcDirect.getPipeline()) { - indent(); - os << "#pragma HLS pipeline II=" << funcDirect.getTargetInterval() << "\n"; + indent() << "#pragma HLS pipeline II=" << funcDirect.getTargetInterval() + << "\n"; // An empty line. os << "\n"; } else if (funcDirect.getDataflow()) { - indent(); - os << "#pragma HLS dataflow\n"; + indent() << "#pragma HLS dataflow\n"; // An empty line. os << "\n"; @@ -1758,6 +1786,8 @@ void ModuleEmitter::emitFunction(FuncOp func) { indent(); if (arg.getType().isa()) emitArrayDecl(arg); + else if (arg.getType().isa()) + emitValue(arg, /*rank=*/0, /*isPtr=*/false, /*isRef=*/true); else emitValue(arg);