[EmitHLSCpp] Support to emit stream.channel/read/write op; [HLSCpp] stream.write must take a value as input now; [CreateTokenFlow][HoistStreamChannel] Update these passes accordingly
This commit is contained in:
parent
0ea2f79303
commit
0b8c03d300
|
@ -37,7 +37,7 @@ def StreamReadOp : HLSCppOp<"stream.read"> {
|
|||
def StreamWriteOp : HLSCppOp<"stream.write"> {
|
||||
let summary = "Stream channel write operation";
|
||||
|
||||
let arguments = (ins StreamOf<[AnyType]>:$channel, Optional<AnyType>:$value);
|
||||
let arguments = (ins StreamOf<[AnyType]>:$channel, AnyType:$value);
|
||||
}
|
||||
|
||||
def StreamBufferOp : HLSCppOp<"stream.buffer"> {
|
||||
|
|
|
@ -52,8 +52,9 @@ public:
|
|||
memref::ExpandShapeOp, memref::ReinterpretCastOp,
|
||||
bufferization::ToMemrefOp, bufferization::ToTensorOp,
|
||||
|
||||
// HLSCpp primitive operations.
|
||||
PrimMulOp, PrimCastOp, BufferOp,
|
||||
// HLSCpp dialect operations.
|
||||
StreamChannelOp, StreamReadOp, StreamWriteOp, PrimMulOp, PrimCastOp,
|
||||
BufferOp,
|
||||
|
||||
// Control flow operations.
|
||||
func::CallOp, func::ReturnOp,
|
||||
|
@ -145,7 +146,10 @@ public:
|
|||
HANDLE(bufferization::ToMemrefOp);
|
||||
HANDLE(bufferization::ToTensorOp);
|
||||
|
||||
// HLSCpp primitive operations.
|
||||
// HLSCpp dialect operations.
|
||||
HANDLE(StreamChannelOp);
|
||||
HANDLE(StreamReadOp);
|
||||
HANDLE(StreamWriteOp);
|
||||
HANDLE(PrimMulOp);
|
||||
HANDLE(PrimCastOp);
|
||||
HANDLE(BufferOp);
|
||||
|
|
|
@ -51,7 +51,9 @@ struct CreateTokenFlow : public CreateTokenFlowBase<CreateTokenFlow> {
|
|||
// Create a new token stream channel.
|
||||
builder.setInsertionPoint(returnOp);
|
||||
auto channel = builder.create<hlscpp::StreamChannelOp>(loc, tokenType);
|
||||
builder.create<hlscpp::StreamWriteOp>(loc, channel, Value());
|
||||
auto tokenValue =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
|
||||
builder.create<hlscpp::StreamWriteOp>(loc, channel, tokenValue);
|
||||
|
||||
// Collect the users of the stream channel.
|
||||
for (auto user : result.getUsers()) {
|
||||
|
|
|
@ -129,7 +129,10 @@ struct LowerStreamBufferOpRewritePattern
|
|||
buffer, buffer.getType());
|
||||
|
||||
rewriter.setInsertionPoint(block->getTerminator());
|
||||
rewriter.create<hlscpp::StreamWriteOp>(loc, channel, Value());
|
||||
auto tokenType = channel.getType().cast<StreamType>().getElementType();
|
||||
auto tokenValue = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(tokenType));
|
||||
rewriter.create<hlscpp::StreamWriteOp>(loc, channel, tokenValue);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -26,6 +26,8 @@ static SmallString<16> getTypeName(Value val) {
|
|||
auto valType = val.getType();
|
||||
if (auto arrayType = val.getType().dyn_cast<ShapedType>())
|
||||
valType = arrayType.getElementType();
|
||||
else if (auto streamType = val.getType().dyn_cast<StreamType>())
|
||||
valType = streamType.getElementType();
|
||||
|
||||
// Handle float types.
|
||||
if (valType.isa<Float32Type>())
|
||||
|
@ -224,7 +226,10 @@ public:
|
|||
void emitTensorToMemref(bufferization::ToMemrefOp op);
|
||||
void emitMemrefToTensor(bufferization::ToTensorOp op);
|
||||
|
||||
/// HLSCpp primitive operation emitters.
|
||||
/// HLSCpp dialect operation emitters.
|
||||
void emitStreamChannel(StreamChannelOp op);
|
||||
void emitStreamRead(StreamReadOp op);
|
||||
void emitStreamWrite(StreamWriteOp op);
|
||||
void emitPrimMul(PrimMulOp op);
|
||||
template <typename AssignOpType> void emitAssign(AssignOpType op);
|
||||
|
||||
|
@ -249,7 +254,8 @@ private:
|
|||
SmallVector<SmallString<8>, 4> getTransferIndices(TransferOpType op);
|
||||
|
||||
/// C++ component emitters.
|
||||
void emitValue(Value val, unsigned rank = 0, bool isPtr = false);
|
||||
void emitValue(Value val, unsigned rank = 0, bool isPtr = false,
|
||||
bool isRef = false);
|
||||
void emitArrayDecl(Value array);
|
||||
unsigned emitNestedLoopHeader(Value val);
|
||||
void emitNestedLoopFooter(unsigned rank);
|
||||
|
@ -422,7 +428,13 @@ public:
|
|||
bool visitOp(bufferization::ToTensorOp op) {
|
||||
return emitter.emitMemrefToTensor(op), true;
|
||||
}
|
||||
/// HLSCpp primitive operations.
|
||||
|
||||
/// HLSCpp dialect operations.
|
||||
bool visitOp(StreamChannelOp op) {
|
||||
return emitter.emitStreamChannel(op), true;
|
||||
}
|
||||
bool visitOp(StreamReadOp op) { return emitter.emitStreamRead(op), true; }
|
||||
bool visitOp(StreamWriteOp op) { return emitter.emitStreamWrite(op), true; }
|
||||
bool visitOp(PrimMulOp op) { return emitter.emitPrimMul(op), true; }
|
||||
bool visitOp(PrimCastOp op) { return emitter.emitAssign(op), true; }
|
||||
bool visitOp(BufferOp op) { return emitter.emitAssign(op), true; }
|
||||
|
@ -572,8 +584,7 @@ bool ExprVisitor::visitOp(arith::CmpIOp op) {
|
|||
|
||||
/// SCF statement emitters.
|
||||
void ModuleEmitter::emitScfFor(scf::ForOp op) {
|
||||
indent();
|
||||
os << "for (";
|
||||
indent() << "for (";
|
||||
auto iterVar = op.getInductionVar();
|
||||
|
||||
// Emit lower bound.
|
||||
|
@ -601,8 +612,7 @@ void ModuleEmitter::emitScfFor(scf::ForOp op) {
|
|||
emitBlock(*op.getBody());
|
||||
reduceIndent();
|
||||
|
||||
indent();
|
||||
os << "}\n";
|
||||
indent() << "}\n";
|
||||
}
|
||||
|
||||
void ModuleEmitter::emitScfIf(scf::IfOp op) {
|
||||
|
@ -619,8 +629,7 @@ void ModuleEmitter::emitScfIf(scf::IfOp op) {
|
|||
}
|
||||
}
|
||||
|
||||
indent();
|
||||
os << "if (";
|
||||
indent() << "if (";
|
||||
emitValue(op.getCondition());
|
||||
os << ") {";
|
||||
emitInfoAndNewLine(op);
|
||||
|
@ -630,15 +639,13 @@ void ModuleEmitter::emitScfIf(scf::IfOp op) {
|
|||
reduceIndent();
|
||||
|
||||
if (!op.getElseRegion().empty()) {
|
||||
indent();
|
||||
os << "} else {\n";
|
||||
indent() << "} else {\n";
|
||||
addIndent();
|
||||
emitBlock(op.getElseRegion().front());
|
||||
reduceIndent();
|
||||
}
|
||||
|
||||
indent();
|
||||
os << "}\n";
|
||||
indent() << "}\n";
|
||||
}
|
||||
|
||||
void ModuleEmitter::emitScfYield(scf::YieldOp op) {
|
||||
|
@ -664,8 +671,7 @@ void ModuleEmitter::emitScfYield(scf::YieldOp op) {
|
|||
|
||||
/// Affine statement emitters.
|
||||
void ModuleEmitter::emitAffineFor(AffineForOp op) {
|
||||
indent();
|
||||
os << "for (";
|
||||
indent() << "for (";
|
||||
auto iterVar = op.getInductionVar();
|
||||
|
||||
// Emit lower bound.
|
||||
|
@ -719,8 +725,7 @@ void ModuleEmitter::emitAffineFor(AffineForOp op) {
|
|||
emitBlock(*op.getBody());
|
||||
reduceIndent();
|
||||
|
||||
indent();
|
||||
os << "}\n";
|
||||
indent() << "}\n";
|
||||
}
|
||||
|
||||
void ModuleEmitter::emitAffineIf(AffineIfOp op) {
|
||||
|
@ -737,8 +742,7 @@ void ModuleEmitter::emitAffineIf(AffineIfOp op) {
|
|||
}
|
||||
}
|
||||
|
||||
indent();
|
||||
os << "if (";
|
||||
indent() << "if (";
|
||||
auto constrSet = op.getIntegerSet();
|
||||
AffineExprEmitter constrEmitter(state, constrSet.getNumDims(),
|
||||
op.getOperands());
|
||||
|
@ -763,15 +767,13 @@ void ModuleEmitter::emitAffineIf(AffineIfOp op) {
|
|||
reduceIndent();
|
||||
|
||||
if (op.hasElse()) {
|
||||
indent();
|
||||
os << "} else {\n";
|
||||
indent() << "} else {\n";
|
||||
addIndent();
|
||||
emitBlock(*op.getElseBlock());
|
||||
reduceIndent();
|
||||
}
|
||||
|
||||
indent();
|
||||
os << "}\n";
|
||||
indent() << "}\n";
|
||||
}
|
||||
|
||||
void ModuleEmitter::emitAffineParallel(AffineParallelOp op) {
|
||||
|
@ -790,8 +792,7 @@ void ModuleEmitter::emitAffineParallel(AffineParallelOp op) {
|
|||
|
||||
auto steps = getIntArrayAttrValue(op, op.getStepsAttrName());
|
||||
for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
|
||||
indent();
|
||||
os << "for (";
|
||||
indent() << "for (";
|
||||
auto iterVar = op.getBody()->getArgument(i);
|
||||
|
||||
// Emit lower bound.
|
||||
|
@ -825,8 +826,7 @@ void ModuleEmitter::emitAffineParallel(AffineParallelOp op) {
|
|||
for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
|
||||
reduceIndent();
|
||||
|
||||
indent();
|
||||
os << "}\n";
|
||||
indent() << "}\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -918,8 +918,7 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) {
|
|||
emitNestedLoopFooter(rank);
|
||||
}
|
||||
} else if (auto parentOp = dyn_cast<AffineParallelOp>(op->getParentOp())) {
|
||||
indent();
|
||||
os << "if (";
|
||||
indent() << "if (";
|
||||
unsigned ivIdx = 0;
|
||||
for (auto iv : parentOp.getBody()->getArguments()) {
|
||||
emitValue(iv);
|
||||
|
@ -945,8 +944,7 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) {
|
|||
}
|
||||
reduceIndent();
|
||||
|
||||
indent();
|
||||
os << "} else {\n";
|
||||
indent() << "} else {\n";
|
||||
|
||||
// Otherwise, generated values will be accumulated/reduced to the
|
||||
// current results with corresponding arith::AtomicRMWKind operations.
|
||||
|
@ -1006,8 +1004,7 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) {
|
|||
}
|
||||
reduceIndent();
|
||||
|
||||
indent();
|
||||
os << "}\n";
|
||||
indent() << "}\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1212,8 +1209,7 @@ template <typename OpType> void ModuleEmitter::emitReshape(OpType op) {
|
|||
assert(!isDeclared(array) && "has been declared before.");
|
||||
|
||||
auto arrayType = array.getType().template cast<ShapedType>();
|
||||
indent();
|
||||
os << getTypeName(array) << " (*";
|
||||
indent() << getTypeName(array) << " (*";
|
||||
|
||||
// Add the new value to nameTable and emit its name.
|
||||
os << addName(array, false);
|
||||
|
@ -1267,7 +1263,35 @@ void ModuleEmitter::emitMemrefToTensor(bufferization::ToTensorOp op) {
|
|||
}
|
||||
}
|
||||
|
||||
/// HLSCpp primitive operation emitters.
|
||||
/// HLSCpp dialect operation emitters.
|
||||
void ModuleEmitter::emitStreamChannel(StreamChannelOp op) {
|
||||
indent();
|
||||
emitValue(op.channel());
|
||||
os << ";";
|
||||
emitInfoAndNewLine(op);
|
||||
}
|
||||
|
||||
void ModuleEmitter::emitStreamRead(StreamReadOp op) {
|
||||
indent();
|
||||
if (op.result()) {
|
||||
emitValue(op.result());
|
||||
os << " = ";
|
||||
}
|
||||
emitValue(op.channel());
|
||||
os << ".read(";
|
||||
os << ");";
|
||||
emitInfoAndNewLine(op);
|
||||
}
|
||||
|
||||
void ModuleEmitter::emitStreamWrite(StreamWriteOp op) {
|
||||
indent();
|
||||
emitValue(op.channel());
|
||||
os << ".write(";
|
||||
emitValue(op.value());
|
||||
os << ");";
|
||||
emitInfoAndNewLine(op);
|
||||
}
|
||||
|
||||
void ModuleEmitter::emitPrimMul(PrimMulOp op) {
|
||||
if (op.isPackMul()) {
|
||||
// Declare the result C array.
|
||||
|
@ -1334,8 +1358,7 @@ void ModuleEmitter::emitCall(func::CallOp op) {
|
|||
}
|
||||
|
||||
// Emit the function call.
|
||||
indent();
|
||||
os << op.getCallee() << "(";
|
||||
indent() << op.getCallee() << "(";
|
||||
|
||||
// Handle input arguments.
|
||||
unsigned argIdx = 0;
|
||||
|
@ -1470,7 +1493,8 @@ void ModuleEmitter::emitConstant(arith::ConstantOp op) {
|
|||
}
|
||||
|
||||
/// C++ component emitters.
|
||||
void ModuleEmitter::emitValue(Value val, unsigned rank, bool isPtr) {
|
||||
void ModuleEmitter::emitValue(Value val, unsigned rank, bool isPtr,
|
||||
bool isRef) {
|
||||
assert(!(rank && isPtr) && "should be either an array or a pointer.");
|
||||
|
||||
// Value has been declared before or is a constant number.
|
||||
|
@ -1481,8 +1505,14 @@ void ModuleEmitter::emitValue(Value val, unsigned rank, bool isPtr) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (val.getType().isa<StreamType>())
|
||||
os << "hls::stream<" << getTypeName(val) << "> ";
|
||||
else
|
||||
os << getTypeName(val) << " ";
|
||||
|
||||
if (isRef)
|
||||
os << "&";
|
||||
|
||||
// Add the new value to nameTable and emit its name.
|
||||
os << addName(val, isPtr);
|
||||
for (unsigned i = 0; i < rank; ++i)
|
||||
|
@ -1527,8 +1557,7 @@ unsigned ModuleEmitter::emitNestedLoopHeader(Value val) {
|
|||
// Create nested loop.
|
||||
unsigned dimIdx = 0;
|
||||
for (auto &shape : type.getShape()) {
|
||||
indent();
|
||||
os << "for (int iv" << dimIdx << " = 0; ";
|
||||
indent() << "for (int iv" << dimIdx << " = 0; ";
|
||||
os << "iv" << dimIdx << " < " << shape << "; ";
|
||||
os << "++iv" << dimIdx++ << ") {\n";
|
||||
|
||||
|
@ -1548,8 +1577,7 @@ void ModuleEmitter::emitNestedLoopFooter(unsigned rank) {
|
|||
for (unsigned i = 0; i < rank; ++i) {
|
||||
reduceIndent();
|
||||
|
||||
indent();
|
||||
os << "}\n";
|
||||
indent() << "}\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1589,14 +1617,10 @@ void ModuleEmitter::emitLoopDirectives(Operation *op) {
|
|||
if (!loopDirect)
|
||||
return;
|
||||
|
||||
if (loopDirect.getPipeline()) {
|
||||
indent();
|
||||
os << "#pragma HLS pipeline II=" << loopDirect.getTargetII() << "\n";
|
||||
|
||||
} else if (loopDirect.getDataflow()) {
|
||||
indent();
|
||||
os << "#pragma HLS dataflow\n";
|
||||
}
|
||||
if (loopDirect.getPipeline())
|
||||
indent() << "#pragma HLS pipeline II=" << loopDirect.getTargetII() << "\n";
|
||||
else if (loopDirect.getDataflow())
|
||||
indent() << "#pragma HLS dataflow\n";
|
||||
}
|
||||
|
||||
void ModuleEmitter::emitArrayDirectives(Value memref) {
|
||||
|
@ -1612,8 +1636,7 @@ void ModuleEmitter::emitArrayDirectives(Value memref) {
|
|||
if (factors[dim] != 1) {
|
||||
emitPragmaFlag = true;
|
||||
|
||||
indent();
|
||||
os << "#pragma HLS array_partition";
|
||||
indent() << "#pragma HLS array_partition";
|
||||
os << " variable=";
|
||||
emitValue(memref);
|
||||
|
||||
|
@ -1634,8 +1657,7 @@ void ModuleEmitter::emitArrayDirectives(Value memref) {
|
|||
if (kind != MemoryKind::DRAM && !isFullyPartitioned(type)) {
|
||||
emitPragmaFlag = true;
|
||||
|
||||
indent();
|
||||
os << "#pragma HLS resource";
|
||||
indent() << "#pragma HLS resource";
|
||||
os << " variable=";
|
||||
emitValue(memref);
|
||||
|
||||
|
@ -1651,6 +1673,16 @@ void ModuleEmitter::emitArrayDirectives(Value memref) {
|
|||
os << "\n";
|
||||
}
|
||||
|
||||
// Emit DRAM variable as stable.
|
||||
if (kind == MemoryKind::DRAM) {
|
||||
emitPragmaFlag = true;
|
||||
|
||||
indent() << "#pragma HLS stable";
|
||||
os << " variable=";
|
||||
emitValue(memref);
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
// Emit an empty line.
|
||||
if (emitPragmaFlag)
|
||||
os << "\n";
|
||||
|
@ -1660,8 +1692,7 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func,
|
|||
ArrayRef<Value> portList) {
|
||||
// Only top function should emit interface pragmas.
|
||||
if (hasTopFuncAttr(func)) {
|
||||
indent();
|
||||
os << "#pragma HLS interface s_axilite port=return bundle=ctrl\n";
|
||||
indent() << "#pragma HLS interface s_axilite port=return bundle=ctrl\n";
|
||||
|
||||
for (auto &port : portList) {
|
||||
// Array ports and scalar ports are handled separately. Here, we only
|
||||
|
@ -1669,8 +1700,7 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func,
|
|||
if (auto memrefType = port.getType().dyn_cast<MemRefType>()) {
|
||||
// Only emit interface pragma when the array is not fully partitioned.
|
||||
if (!isFullyPartitioned(memrefType)) {
|
||||
indent();
|
||||
os << "#pragma HLS interface";
|
||||
indent() << "#pragma HLS interface";
|
||||
// For now, we set the offset of all m_axi interfaces as slave.
|
||||
if (MemoryKind(memrefType.getMemorySpaceAsInt()) ==
|
||||
MemoryKind::DRAM) {
|
||||
|
@ -1684,8 +1714,7 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func,
|
|||
os << "\n";
|
||||
}
|
||||
} else {
|
||||
indent();
|
||||
os << "#pragma HLS interface s_axilite";
|
||||
indent() << "#pragma HLS interface s_axilite";
|
||||
os << " port=";
|
||||
|
||||
// TODO: This is a temporary solution.
|
||||
|
@ -1711,14 +1740,13 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func,
|
|||
return;
|
||||
|
||||
if (funcDirect.getPipeline()) {
|
||||
indent();
|
||||
os << "#pragma HLS pipeline II=" << funcDirect.getTargetInterval() << "\n";
|
||||
indent() << "#pragma HLS pipeline II=" << funcDirect.getTargetInterval()
|
||||
<< "\n";
|
||||
|
||||
// An empty line.
|
||||
os << "\n";
|
||||
} else if (funcDirect.getDataflow()) {
|
||||
indent();
|
||||
os << "#pragma HLS dataflow\n";
|
||||
indent() << "#pragma HLS dataflow\n";
|
||||
|
||||
// An empty line.
|
||||
os << "\n";
|
||||
|
@ -1758,6 +1786,8 @@ void ModuleEmitter::emitFunction(FuncOp func) {
|
|||
indent();
|
||||
if (arg.getType().isa<ShapedType>())
|
||||
emitArrayDecl(arg);
|
||||
else if (arg.getType().isa<StreamType>())
|
||||
emitValue(arg, /*rank=*/0, /*isPtr=*/false, /*isRef=*/true);
|
||||
else
|
||||
emitValue(arg);
|
||||
|
||||
|
|
Loading…
Reference in New Issue