[EmitHLSCpp] Support to emit stream.channel/read/write op; [HLSCpp] stream.write must take a value as input now; [CreateTokenFlow][HoistStreamChannel] Update these passes accordingly

This commit is contained in:
Hanchen Ye 2022-03-20 14:45:06 -05:00
parent 0ea2f79303
commit 0b8c03d300
5 changed files with 111 additions and 72 deletions

View File

@ -37,7 +37,7 @@ def StreamReadOp : HLSCppOp<"stream.read"> {
def StreamWriteOp : HLSCppOp<"stream.write"> {
let summary = "Stream channel write operation";
let arguments = (ins StreamOf<[AnyType]>:$channel, Optional<AnyType>:$value);
let arguments = (ins StreamOf<[AnyType]>:$channel, AnyType:$value);
}
def StreamBufferOp : HLSCppOp<"stream.buffer"> {

View File

@ -52,8 +52,9 @@ public:
memref::ExpandShapeOp, memref::ReinterpretCastOp,
bufferization::ToMemrefOp, bufferization::ToTensorOp,
// HLSCpp primitive operations.
PrimMulOp, PrimCastOp, BufferOp,
// HLSCpp dialect operations.
StreamChannelOp, StreamReadOp, StreamWriteOp, PrimMulOp, PrimCastOp,
BufferOp,
// Control flow operations.
func::CallOp, func::ReturnOp,
@ -145,7 +146,10 @@ public:
HANDLE(bufferization::ToMemrefOp);
HANDLE(bufferization::ToTensorOp);
// HLSCpp primitive operations.
// HLSCpp dialect operations.
HANDLE(StreamChannelOp);
HANDLE(StreamReadOp);
HANDLE(StreamWriteOp);
HANDLE(PrimMulOp);
HANDLE(PrimCastOp);
HANDLE(BufferOp);

View File

@ -51,7 +51,9 @@ struct CreateTokenFlow : public CreateTokenFlowBase<CreateTokenFlow> {
// Create a new token stream channel.
builder.setInsertionPoint(returnOp);
auto channel = builder.create<hlscpp::StreamChannelOp>(loc, tokenType);
builder.create<hlscpp::StreamWriteOp>(loc, channel, Value());
auto tokenValue =
builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false));
builder.create<hlscpp::StreamWriteOp>(loc, channel, tokenValue);
// Collect the users of the stream channel.
for (auto user : result.getUsers()) {

View File

@ -129,7 +129,10 @@ struct LowerStreamBufferOpRewritePattern
buffer, buffer.getType());
rewriter.setInsertionPoint(block->getTerminator());
rewriter.create<hlscpp::StreamWriteOp>(loc, channel, Value());
auto tokenType = channel.getType().cast<StreamType>().getElementType();
auto tokenValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(tokenType));
rewriter.create<hlscpp::StreamWriteOp>(loc, channel, tokenValue);
return success();
}
};

View File

@ -26,6 +26,8 @@ static SmallString<16> getTypeName(Value val) {
auto valType = val.getType();
if (auto arrayType = val.getType().dyn_cast<ShapedType>())
valType = arrayType.getElementType();
else if (auto streamType = val.getType().dyn_cast<StreamType>())
valType = streamType.getElementType();
// Handle float types.
if (valType.isa<Float32Type>())
@ -224,7 +226,10 @@ public:
void emitTensorToMemref(bufferization::ToMemrefOp op);
void emitMemrefToTensor(bufferization::ToTensorOp op);
/// HLSCpp primitive operation emitters.
/// HLSCpp dialect operation emitters.
void emitStreamChannel(StreamChannelOp op);
void emitStreamRead(StreamReadOp op);
void emitStreamWrite(StreamWriteOp op);
void emitPrimMul(PrimMulOp op);
template <typename AssignOpType> void emitAssign(AssignOpType op);
@ -249,7 +254,8 @@ private:
SmallVector<SmallString<8>, 4> getTransferIndices(TransferOpType op);
/// C++ component emitters.
void emitValue(Value val, unsigned rank = 0, bool isPtr = false);
void emitValue(Value val, unsigned rank = 0, bool isPtr = false,
bool isRef = false);
void emitArrayDecl(Value array);
unsigned emitNestedLoopHeader(Value val);
void emitNestedLoopFooter(unsigned rank);
@ -422,7 +428,13 @@ public:
bool visitOp(bufferization::ToTensorOp op) {
return emitter.emitMemrefToTensor(op), true;
}
/// HLSCpp primitive operations.
/// HLSCpp dialect operations.
bool visitOp(StreamChannelOp op) {
return emitter.emitStreamChannel(op), true;
}
bool visitOp(StreamReadOp op) { return emitter.emitStreamRead(op), true; }
bool visitOp(StreamWriteOp op) { return emitter.emitStreamWrite(op), true; }
bool visitOp(PrimMulOp op) { return emitter.emitPrimMul(op), true; }
bool visitOp(PrimCastOp op) { return emitter.emitAssign(op), true; }
bool visitOp(BufferOp op) { return emitter.emitAssign(op), true; }
@ -572,8 +584,7 @@ bool ExprVisitor::visitOp(arith::CmpIOp op) {
/// SCF statement emitters.
void ModuleEmitter::emitScfFor(scf::ForOp op) {
indent();
os << "for (";
indent() << "for (";
auto iterVar = op.getInductionVar();
// Emit lower bound.
@ -601,8 +612,7 @@ void ModuleEmitter::emitScfFor(scf::ForOp op) {
emitBlock(*op.getBody());
reduceIndent();
indent();
os << "}\n";
indent() << "}\n";
}
void ModuleEmitter::emitScfIf(scf::IfOp op) {
@ -619,8 +629,7 @@ void ModuleEmitter::emitScfIf(scf::IfOp op) {
}
}
indent();
os << "if (";
indent() << "if (";
emitValue(op.getCondition());
os << ") {";
emitInfoAndNewLine(op);
@ -630,15 +639,13 @@ void ModuleEmitter::emitScfIf(scf::IfOp op) {
reduceIndent();
if (!op.getElseRegion().empty()) {
indent();
os << "} else {\n";
indent() << "} else {\n";
addIndent();
emitBlock(op.getElseRegion().front());
reduceIndent();
}
indent();
os << "}\n";
indent() << "}\n";
}
void ModuleEmitter::emitScfYield(scf::YieldOp op) {
@ -664,8 +671,7 @@ void ModuleEmitter::emitScfYield(scf::YieldOp op) {
/// Affine statement emitters.
void ModuleEmitter::emitAffineFor(AffineForOp op) {
indent();
os << "for (";
indent() << "for (";
auto iterVar = op.getInductionVar();
// Emit lower bound.
@ -719,8 +725,7 @@ void ModuleEmitter::emitAffineFor(AffineForOp op) {
emitBlock(*op.getBody());
reduceIndent();
indent();
os << "}\n";
indent() << "}\n";
}
void ModuleEmitter::emitAffineIf(AffineIfOp op) {
@ -737,8 +742,7 @@ void ModuleEmitter::emitAffineIf(AffineIfOp op) {
}
}
indent();
os << "if (";
indent() << "if (";
auto constrSet = op.getIntegerSet();
AffineExprEmitter constrEmitter(state, constrSet.getNumDims(),
op.getOperands());
@ -763,15 +767,13 @@ void ModuleEmitter::emitAffineIf(AffineIfOp op) {
reduceIndent();
if (op.hasElse()) {
indent();
os << "} else {\n";
indent() << "} else {\n";
addIndent();
emitBlock(*op.getElseBlock());
reduceIndent();
}
indent();
os << "}\n";
indent() << "}\n";
}
void ModuleEmitter::emitAffineParallel(AffineParallelOp op) {
@ -790,8 +792,7 @@ void ModuleEmitter::emitAffineParallel(AffineParallelOp op) {
auto steps = getIntArrayAttrValue(op, op.getStepsAttrName());
for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
indent();
os << "for (";
indent() << "for (";
auto iterVar = op.getBody()->getArgument(i);
// Emit lower bound.
@ -825,8 +826,7 @@ void ModuleEmitter::emitAffineParallel(AffineParallelOp op) {
for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
reduceIndent();
indent();
os << "}\n";
indent() << "}\n";
}
}
@ -918,8 +918,7 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) {
emitNestedLoopFooter(rank);
}
} else if (auto parentOp = dyn_cast<AffineParallelOp>(op->getParentOp())) {
indent();
os << "if (";
indent() << "if (";
unsigned ivIdx = 0;
for (auto iv : parentOp.getBody()->getArguments()) {
emitValue(iv);
@ -945,8 +944,7 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) {
}
reduceIndent();
indent();
os << "} else {\n";
indent() << "} else {\n";
// Otherwise, generated values will be accumulated/reduced to the
// current results with corresponding arith::AtomicRMWKind operations.
@ -1006,8 +1004,7 @@ void ModuleEmitter::emitAffineYield(AffineYieldOp op) {
}
reduceIndent();
indent();
os << "}\n";
indent() << "}\n";
}
}
@ -1212,8 +1209,7 @@ template <typename OpType> void ModuleEmitter::emitReshape(OpType op) {
assert(!isDeclared(array) && "has been declared before.");
auto arrayType = array.getType().template cast<ShapedType>();
indent();
os << getTypeName(array) << " (*";
indent() << getTypeName(array) << " (*";
// Add the new value to nameTable and emit its name.
os << addName(array, false);
@ -1267,7 +1263,35 @@ void ModuleEmitter::emitMemrefToTensor(bufferization::ToTensorOp op) {
}
}
/// HLSCpp primitive operation emitters.
/// HLSCpp dialect operation emitters.
void ModuleEmitter::emitStreamChannel(StreamChannelOp op) {
indent();
emitValue(op.channel());
os << ";";
emitInfoAndNewLine(op);
}
void ModuleEmitter::emitStreamRead(StreamReadOp op) {
indent();
if (op.result()) {
emitValue(op.result());
os << " = ";
}
emitValue(op.channel());
os << ".read(";
os << ");";
emitInfoAndNewLine(op);
}
void ModuleEmitter::emitStreamWrite(StreamWriteOp op) {
indent();
emitValue(op.channel());
os << ".write(";
emitValue(op.value());
os << ");";
emitInfoAndNewLine(op);
}
void ModuleEmitter::emitPrimMul(PrimMulOp op) {
if (op.isPackMul()) {
// Declare the result C array.
@ -1334,8 +1358,7 @@ void ModuleEmitter::emitCall(func::CallOp op) {
}
// Emit the function call.
indent();
os << op.getCallee() << "(";
indent() << op.getCallee() << "(";
// Handle input arguments.
unsigned argIdx = 0;
@ -1470,7 +1493,8 @@ void ModuleEmitter::emitConstant(arith::ConstantOp op) {
}
/// C++ component emitters.
void ModuleEmitter::emitValue(Value val, unsigned rank, bool isPtr) {
void ModuleEmitter::emitValue(Value val, unsigned rank, bool isPtr,
bool isRef) {
assert(!(rank && isPtr) && "should be either an array or a pointer.");
// Value has been declared before or is a constant number.
@ -1481,7 +1505,13 @@ void ModuleEmitter::emitValue(Value val, unsigned rank, bool isPtr) {
return;
}
os << getTypeName(val) << " ";
if (val.getType().isa<StreamType>())
os << "hls::stream<" << getTypeName(val) << "> ";
else
os << getTypeName(val) << " ";
if (isRef)
os << "&";
// Add the new value to nameTable and emit its name.
os << addName(val, isPtr);
@ -1527,8 +1557,7 @@ unsigned ModuleEmitter::emitNestedLoopHeader(Value val) {
// Create nested loop.
unsigned dimIdx = 0;
for (auto &shape : type.getShape()) {
indent();
os << "for (int iv" << dimIdx << " = 0; ";
indent() << "for (int iv" << dimIdx << " = 0; ";
os << "iv" << dimIdx << " < " << shape << "; ";
os << "++iv" << dimIdx++ << ") {\n";
@ -1548,8 +1577,7 @@ void ModuleEmitter::emitNestedLoopFooter(unsigned rank) {
for (unsigned i = 0; i < rank; ++i) {
reduceIndent();
indent();
os << "}\n";
indent() << "}\n";
}
}
@ -1589,14 +1617,10 @@ void ModuleEmitter::emitLoopDirectives(Operation *op) {
if (!loopDirect)
return;
if (loopDirect.getPipeline()) {
indent();
os << "#pragma HLS pipeline II=" << loopDirect.getTargetII() << "\n";
} else if (loopDirect.getDataflow()) {
indent();
os << "#pragma HLS dataflow\n";
}
if (loopDirect.getPipeline())
indent() << "#pragma HLS pipeline II=" << loopDirect.getTargetII() << "\n";
else if (loopDirect.getDataflow())
indent() << "#pragma HLS dataflow\n";
}
void ModuleEmitter::emitArrayDirectives(Value memref) {
@ -1612,8 +1636,7 @@ void ModuleEmitter::emitArrayDirectives(Value memref) {
if (factors[dim] != 1) {
emitPragmaFlag = true;
indent();
os << "#pragma HLS array_partition";
indent() << "#pragma HLS array_partition";
os << " variable=";
emitValue(memref);
@ -1634,8 +1657,7 @@ void ModuleEmitter::emitArrayDirectives(Value memref) {
if (kind != MemoryKind::DRAM && !isFullyPartitioned(type)) {
emitPragmaFlag = true;
indent();
os << "#pragma HLS resource";
indent() << "#pragma HLS resource";
os << " variable=";
emitValue(memref);
@ -1651,6 +1673,16 @@ void ModuleEmitter::emitArrayDirectives(Value memref) {
os << "\n";
}
// Emit DRAM variable as stable.
if (kind == MemoryKind::DRAM) {
emitPragmaFlag = true;
indent() << "#pragma HLS stable";
os << " variable=";
emitValue(memref);
os << "\n";
}
// Emit an empty line.
if (emitPragmaFlag)
os << "\n";
@ -1660,8 +1692,7 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func,
ArrayRef<Value> portList) {
// Only top function should emit interface pragmas.
if (hasTopFuncAttr(func)) {
indent();
os << "#pragma HLS interface s_axilite port=return bundle=ctrl\n";
indent() << "#pragma HLS interface s_axilite port=return bundle=ctrl\n";
for (auto &port : portList) {
// Array ports and scalar ports are handled separately. Here, we only
@ -1669,8 +1700,7 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func,
if (auto memrefType = port.getType().dyn_cast<MemRefType>()) {
// Only emit interface pragma when the array is not fully partitioned.
if (!isFullyPartitioned(memrefType)) {
indent();
os << "#pragma HLS interface";
indent() << "#pragma HLS interface";
// For now, we set the offset of all m_axi interfaces as slave.
if (MemoryKind(memrefType.getMemorySpaceAsInt()) ==
MemoryKind::DRAM) {
@ -1684,8 +1714,7 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func,
os << "\n";
}
} else {
indent();
os << "#pragma HLS interface s_axilite";
indent() << "#pragma HLS interface s_axilite";
os << " port=";
// TODO: This is a temporary solution.
@ -1711,14 +1740,13 @@ void ModuleEmitter::emitFunctionDirectives(FuncOp func,
return;
if (funcDirect.getPipeline()) {
indent();
os << "#pragma HLS pipeline II=" << funcDirect.getTargetInterval() << "\n";
indent() << "#pragma HLS pipeline II=" << funcDirect.getTargetInterval()
<< "\n";
// An empty line.
os << "\n";
} else if (funcDirect.getDataflow()) {
indent();
os << "#pragma HLS dataflow\n";
indent() << "#pragma HLS dataflow\n";
// An empty line.
os << "\n";
@ -1758,6 +1786,8 @@ void ModuleEmitter::emitFunction(FuncOp func) {
indent();
if (arg.getType().isa<ShapedType>())
emitArrayDecl(arg);
else if (arg.getType().isa<StreamType>())
emitValue(arg, /*rank=*/0, /*isPtr=*/false, /*isRef=*/true);
else
emitValue(arg);