[EmitHLSCpp] support tensor/vector function inputs/outputs; expand binary/unary operation emitters to handle tensor/vector inputs

Hanchen Ye 2020-09-05 19:37:18 -05:00
parent d7556c8a0a
commit bdcfa4efe1
2 changed files with 226 additions and 73 deletions


@ -294,10 +294,12 @@ public:
void emitAffineIf(AffineIfOp *op);
void emitAffineParallel(AffineParallelOp *op);
void emitAffineApply(AffineApplyOp *op);
void emitAffineMax(AffineMaxOp *op);
void emitAffineMin(AffineMinOp *op);
template <typename OpType>
void emitAffineMaxMin(OpType *op, const char *syntax);
void emitAffineLoad(AffineLoadOp *op);
void emitAffineStore(AffineStoreOp *op);
void emitAffineVectorLoad(AffineVectorLoadOp *op);
void emitAffineVectorStore(AffineVectorStoreOp *op);
void emitAffineYield(AffineYieldOp *op);
/// Memref-related statement emitters.
@ -316,14 +318,18 @@ public:
void emitRank(RankOp *op);
/// Standard expression emitters.
template <typename ArrayType>
void emitEltWiseBinary(Operation *op, const char *syntax, ArrayType type);
void emitBinary(Operation *op, const char *syntax);
template <typename ArrayType>
void emitEltWiseUnary(Operation *op, const char *syntax, ArrayType type);
void emitUnary(Operation *op, const char *syntax);
/// Special operation emitters.
void emitIndexCast(IndexCastOp *op);
void emitSelect(SelectOp *op);
template <typename ResultType>
void emitCppArray(ConstantOp *op);
template <typename ArrayType>
void emitConstantArray(ConstantOp *op, ArrayType type);
void emitConstant(ConstantOp *op);
void emitCall(CallOp *op);
@ -335,6 +341,8 @@ private:
void emitValue(Value val, bool isPtr = false);
void emitOperation(Operation *op);
void emitBlock(Block &block);
template <typename ArrayType>
void emitArrayDecl(Value val, ArrayType type);
void emitFunction(FuncOp func);
};
} // namespace
@ -418,10 +426,20 @@ public:
return emitter.emitAffineParallel(&op), true;
}
bool visitOp(AffineApplyOp op) { return emitter.emitAffineApply(&op), true; }
bool visitOp(AffineMaxOp op) { return emitter.emitAffineMax(&op), true; }
bool visitOp(AffineMinOp op) { return emitter.emitAffineMin(&op), true; }
bool visitOp(AffineMaxOp op) {
return emitter.emitAffineMaxMin<AffineMaxOp>(&op, "max"), true;
}
bool visitOp(AffineMinOp op) {
return emitter.emitAffineMaxMin<AffineMinOp>(&op, "min"), true;
}
bool visitOp(AffineLoadOp op) { return emitter.emitAffineLoad(&op), true; }
bool visitOp(AffineStoreOp op) { return emitter.emitAffineStore(&op), true; }
bool visitOp(AffineVectorLoadOp op) {
return emitter.emitAffineVectorLoad(&op), true;
}
bool visitOp(AffineVectorStoreOp op) {
return emitter.emitAffineVectorStore(&op), true;
}
bool visitOp(AffineYieldOp op) { return emitter.emitAffineYield(&op), true; }
/// Memref-related statements.
@ -728,7 +746,8 @@ void ModuleEmitter::emitAffineApply(AffineApplyOp *op) {
os << ";\n";
}
void ModuleEmitter::emitAffineMax(AffineMaxOp *op) {
template <typename OpType>
void ModuleEmitter::emitAffineMaxMin(OpType *op, const char *syntax) {
indent();
emitValue(op->getResult());
os << " = ";
@ -736,26 +755,7 @@ void ModuleEmitter::emitAffineMax(AffineMaxOp *op) {
AffineExprEmitter affineEmitter(state, affineMap.getNumDims(),
op->getOperands());
for (unsigned i = 0, e = affineMap.getNumResults() - 1; i < e; ++i) {
os << "max(";
}
affineEmitter.emitAffineExpr(affineMap.getResult(0));
for (auto &expr : llvm::drop_begin(affineMap.getResults(), 1)) {
os << ", ";
affineEmitter.emitAffineExpr(expr);
os << ")";
}
os << ";\n";
}
void ModuleEmitter::emitAffineMin(AffineMinOp *op) {
indent();
emitValue(op->getResult());
os << " = ";
auto affineMap = op->getAffineMap();
AffineExprEmitter affineEmitter(state, affineMap.getNumDims(),
op->getOperands());
for (unsigned i = 0, e = affineMap.getNumResults() - 1; i < e; ++i) {
os << "min(";
os << syntax << "(";
}
affineEmitter.emitAffineExpr(affineMap.getResult(0));
for (auto &expr : llvm::drop_begin(affineMap.getResults(), 1)) {
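To make the merged emitter concrete, here is a hedged sketch of the C++ it would print for an affine_max op whose map has three result expressions (the name v1 and the expressions are illustrative; affine_min prints identically with "min"):

    int v1 = max(max(d0, d0 + 1), d0 * 2);

The "max(" prefix is printed numResults - 1 times before the first expression, so the calls nest from the left exactly as the rewritten loop implies.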
@ -798,6 +798,17 @@ void ModuleEmitter::emitAffineStore(AffineStoreOp *op) {
os << ";\n";
}
void ModuleEmitter::emitAffineVectorLoad(AffineVectorLoadOp *op) {
// TODO
return;
}
void ModuleEmitter::emitAffineVectorStore(AffineVectorStoreOp *op) {
// TODO
return;
}
// TODO: Support tensor/vector
void ModuleEmitter::emitAffineYield(AffineYieldOp *op) {
if (op->getNumOperands() == 0)
return;
@ -909,13 +920,13 @@ void ModuleEmitter::emitAlloc(OpType *op) {
return;
// Vivado HLS only supports static shape on-chip memory.
if (!op->getType().hasStaticShape())
if (!op->getType().hasStaticShape()) {
emitError(*op, "is unranked or has dynamic shape.");
return;
}
indent();
emitValue(op->getResult());
for (auto &shape : op->getType().getShape())
os << "[" << shape << "]";
emitArrayDecl<MemRefType>(op->getResult(), op->getType());
os << ";\n";
}
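After this refactoring, emitAlloc delegates the dimension printing to emitArrayDecl (defined further down in this diff). As an illustrative sketch, allocating a memref<4x4xf32> would print something like:

    float v0[4][4];

where v0 stands for whatever name the emitter assigned to the result value.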
@ -947,23 +958,22 @@ void ModuleEmitter::emitStore(StoreOp *op) {
/// Tensor-related statement emitters.
void ModuleEmitter::emitTensorLoad(TensorLoadOp *op) {
if (!op->getType().hasStaticShape())
if (!op->getType().hasStaticShape()) {
emitError(*op, "is unranked or has dynamic shape.");
return;
}
auto tensorShape = op->getType().getShape();
// Declare a new tensor.
indent();
emitValue(op->getResult());
for (auto &shape : tensorShape)
os << "[" << shape << "]";
emitArrayDecl<TensorType>(op->getResult(), op->getType());
os << ";\n";
// Create a nested loop for loading tensor.
unsigned numDim = tensorShape.size();
unsigned numDim = op->getType().getShape().size();
for (unsigned i = 0; i < numDim; ++i) {
indent();
os << "for (int idx" << i << " = 0; ";
os << "idx" << i << " < " << tensorShape[i] << "; ";
os << "idx" << i << " < " << op->getType().getShape()[i] << "; ";
os << "++idx" << i << ") {\n";
addIndent();
@ -989,8 +999,10 @@ void ModuleEmitter::emitTensorLoad(TensorLoadOp *op) {
void ModuleEmitter::emitTensorStore(TensorStoreOp *op) {
auto tensorType = op->getOperand(0).getType().cast<TensorType>();
if (!tensorType.hasStaticShape())
if (!tensorType.hasStaticShape()) {
emitError(*op, "is unranked or has dynamic shape.");
return;
}
// Create a nested loop for storing tensor.
unsigned numDim = tensorType.getShape().size();
@ -1021,12 +1033,22 @@ void ModuleEmitter::emitTensorStore(TensorStoreOp *op) {
}
}
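For illustration (value names v0/v1 assumed), a tensor_load of a memref<16xi32> now prints a declaration via emitArrayDecl followed by the copy loop; the loop body is elided in the hunk above, but presumably copies element-wise:

    int v1[16];
    for (int idx0 = 0; idx0 < 16; ++idx0) {
      v1[idx0] = v0[idx0];
    }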
void ModuleEmitter::emitSplat(SplatOp *op) { return; }
void ModuleEmitter::emitSplat(SplatOp *op) {
// TODO
return;
}
void ModuleEmitter::emitExtractElement(ExtractElementOp *op) { return; }
void ModuleEmitter::emitExtractElement(ExtractElementOp *op) {
// TODO
return;
}
void ModuleEmitter::emitTensorFromElements(TensorFromElementsOp *op) { return; }
void ModuleEmitter::emitTensorFromElements(TensorFromElementsOp *op) {
// TODO
return;
}
// TODO: support vector.
void ModuleEmitter::emitDim(DimOp *op) {
if (auto constOp = dyn_cast<ConstantOp>(op->getOperand(1).getDefiningOp())) {
auto constVal = constOp.getValue().dyn_cast<IntegerAttr>().getInt();
@ -1059,6 +1081,7 @@ void ModuleEmitter::emitDim(DimOp *op) {
emitError(*op, "index is not a constant.");
}
// TODO: support vector.
void ModuleEmitter::emitRank(RankOp *op) {
if (auto memType = op->getOperand().getType().dyn_cast<MemRefType>()) {
if (memType.hasRank()) {
@ -1081,22 +1104,126 @@ void ModuleEmitter::emitRank(RankOp *op) {
}
/// Standard expression emitters.
void ModuleEmitter::emitBinary(Operation *op, const char *syntax) {
template <typename ArrayType>
void ModuleEmitter::emitEltWiseBinary(Operation *op, const char *syntax,
ArrayType type) {
if (!type.hasStaticShape()) {
emitError(op, "is unranked or has dynamic shape.");
return;
}
// Declare a new tensor or vector to hold the result.
indent();
emitValue(op->getResult(0));
for (auto &shape : type.getShape())
os << "[" << shape << "]";
os << ";\n";
// Create a nested loop that computes the binary op element-wise.
unsigned numDim = type.getShape().size();
for (unsigned i = 0; i < numDim; ++i) {
indent();
os << "for (int idx" << i << " = 0; ";
os << "idx" << i << " < " << type.getShape()[i] << "; ";
os << "++idx" << i << ") {\n";
addIndent();
}
indent();
emitValue(op->getResult(0));
for (unsigned i = 0; i < numDim; ++i)
os << "[idx" << i << "]";
os << " = ";
emitValue(op->getOperand(0));
for (unsigned i = 0; i < numDim; ++i)
os << "[idx" << i << "]";
os << " " << syntax << " ";
emitValue(op->getOperand(1));
for (unsigned i = 0; i < numDim; ++i)
os << "[idx" << i << "]";
os << ";\n";
for (unsigned i = 0; i < numDim; ++i) {
reduceIndent();
indent();
os << "}\n";
}
}
void ModuleEmitter::emitBinary(Operation *op, const char *syntax) {
if (auto type = op->getResult(0).getType().dyn_cast<TensorType>())
emitEltWiseBinary<TensorType>(op, syntax, type);
else if (auto type = op->getResult(0).getType().dyn_cast<VectorType>())
emitEltWiseBinary<VectorType>(op, syntax, type);
else {
indent();
emitValue(op->getResult(0));
os << " = ";
emitValue(op->getOperand(0));
os << " " << syntax << " ";
emitValue(op->getOperand(1));
os << ";\n";
}
}
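A hedged sketch of what emitEltWiseBinary prints for the addi on tensor<2x2xi32> added to the test below (names v0/v1/v2 assumed):

    int v2[2][2];
    for (int idx0 = 0; idx0 < 2; ++idx0) {
      for (int idx1 = 0; idx1 < 2; ++idx1) {
        v2[idx0][idx1] = v0[idx0][idx1] + v1[idx0][idx1];
      }
    }

Scalar operands still take the else branch and print a single infix expression.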
template <typename ArrayType>
void ModuleEmitter::emitEltWiseUnary(Operation *op, const char *syntax,
ArrayType type) {
if (!type.hasStaticShape()) {
emitError(op, "is unranked or has dynamic shape.");
return;
}
// Declare a new tensor or vector to hold the result.
indent();
emitValue(op->getResult(0));
for (auto &shape : type.getShape())
os << "[" << shape << "]";
os << ";\n";
// Create a nested loop that applies the unary op element-wise.
unsigned numDim = type.getShape().size();
for (unsigned i = 0; i < numDim; ++i) {
indent();
os << "for (int idx" << i << " = 0; ";
os << "idx" << i << " < " << type.getShape()[i] << "; ";
os << "++idx" << i << ") {\n";
addIndent();
}
indent();
emitValue(op->getResult(0));
for (unsigned i = 0; i < numDim; ++i)
os << "[idx" << i << "]";
os << " = " << syntax << "(";
emitValue(op->getOperand(0));
for (unsigned i = 0; i < numDim; ++i)
os << "[idx" << i << "]";
os << ");\n";
for (unsigned i = 0; i < numDim; ++i) {
reduceIndent();
indent();
os << "}\n";
}
}
void ModuleEmitter::emitUnary(Operation *op, const char *syntax) {
indent();
emitValue(op->getResult(0));
os << " = " << syntax << "(";
emitValue(op->getOperand(0));
os << ");\n";
if (auto type = op->getResult(0).getType().dyn_cast<TensorType>())
emitEltWiseUnary<TensorType>(op, syntax, type);
else if (auto type = op->getResult(0).getType().dyn_cast<VectorType>())
emitEltWiseUnary<VectorType>(op, syntax, type);
else {
indent();
emitValue(op->getResult(0));
os << " = " << syntax << "(";
emitValue(op->getOperand(0));
os << ");\n";
}
}
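The unary path mirrors the binary one but prints a call instead of an infix operator; for a hypothetical exp over tensor<16xf32> (names assumed):

    float v1[16];
    for (int idx0 = 0; idx0 < 16; ++idx0) {
      v1[idx0] = exp(v0[idx0]);
    }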
/// Special operation emitters.
@ -1120,10 +1247,10 @@ void ModuleEmitter::emitSelect(SelectOp *op) {
os << ";\n";
}
template <typename ResultType>
void ModuleEmitter::emitCppArray(ConstantOp *op) {
template <typename ArrayType>
void ModuleEmitter::emitConstantArray(ConstantOp *op, ArrayType type) {
auto denseAttr = op->getValue().dyn_cast<DenseElementsAttr>();
auto elementType = op->getType().dyn_cast<ResultType>().getElementType();
auto elementType = type.getElementType();
os << "{";
unsigned elementIdx = 0;
if (elementType.isF32()) {
@ -1174,12 +1301,15 @@ void ModuleEmitter::emitConstant(ConstantOp *op) {
return;
} else if (constAttr.isa<DenseElementsAttr>()) {
indent();
emitValue(op->getResult());
os << " = ";
if (op->getType().isa<TensorType>())
emitCppArray<TensorType>(op);
else
emitCppArray<VectorType>(op);
if (auto type = op->getType().dyn_cast<TensorType>()) {
emitArrayDecl<TensorType>(op->getResult(), type);
os << " = ";
emitConstantArray<TensorType>(op, type);
} else if (auto type = op->getType().dyn_cast<VectorType>()) {
emitArrayDecl<VectorType>(op->getResult(), type);
os << " = ";
emitConstantArray<VectorType>(op, type);
}
os << ";\n";
} else
emitError(*op, "has unsupported constant type.");
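Put together, the dense constant tensor<2x2xi32> from the test below now prints as a declaration plus initializer in one statement; an illustrative sketch, assuming the elided emitConstantArray body prints the elements flat in row-major order:

    int v0[2][2] = {23, 0, -23, 0};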
@ -1201,12 +1331,12 @@ void ModuleEmitter::emitValue(Value val, bool isPtr) {
// Handle memref, tensor, and vector types.
auto valType = val.getType();
if (auto memType = valType.dyn_cast<MemRefType>())
valType = memType.getElementType();
else if (auto tensorType = valType.dyn_cast<TensorType>())
valType = tensorType.getElementType();
else if (auto vectorType = valType.dyn_cast<VectorType>())
valType = vectorType.getElementType();
if (auto type = valType.dyn_cast<MemRefType>())
valType = type.getElementType();
else if (auto type = valType.dyn_cast<TensorType>())
valType = type.getElementType();
else if (auto type = valType.dyn_cast<VectorType>())
valType = type.getElementType();
// Emit value type for declaring a new value.
switch (valType.getKind()) {
@ -1255,6 +1385,20 @@ void ModuleEmitter::emitBlock(Block &block) {
emitOperation(&op);
}
template <typename ArrayType>
void ModuleEmitter::emitArrayDecl(Value array, ArrayType type) {
// This indicates that the array has been declared before.
if (!getName(array).empty())
return;
if (type.hasStaticShape()) {
emitValue(array);
for (auto &shape : type.getShape())
os << "[" << shape << "]";
} else
emitValue(array, /*isPtr=*/true);
}
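A short sketch of the two cases (names assumed): a value of static type memref<16x8xf32> prints as an array, while a dynamically shaped value falls back to a bare pointer:

    float v0[16][8]
    float *v1

Note that emitArrayDecl itself prints no trailing semicolon; callers such as emitAlloc append ";\n" themselves.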
void ModuleEmitter::emitFunction(FuncOp func) {
if (func.getBlocks().size() != 1)
emitError(func, "has more than one basic block.");
@ -1267,10 +1411,15 @@ void ModuleEmitter::emitFunction(FuncOp func) {
unsigned argIdx = 0;
for (auto &arg : func.getArguments()) {
indent();
emitValue(arg);
if (auto memType = arg.getType().dyn_cast<MemRefType>())
for (auto &shape : memType.getShape())
os << "[" << shape << "]";
if (auto type = arg.getType().dyn_cast<MemRefType>())
emitArrayDecl<MemRefType>(arg, type);
else if (auto type = arg.getType().dyn_cast<TensorType>())
emitArrayDecl<TensorType>(arg, type);
else if (auto type = arg.getType().dyn_cast<VectorType>())
emitArrayDecl<VectorType>(arg, type);
else
emitValue(arg);
if (argIdx == func.getNumArguments() - 1 && func.getNumResults() == 0)
os << "\n";
else
@ -1283,14 +1432,15 @@ void ModuleEmitter::emitFunction(FuncOp func) {
unsigned resultIdx = 0;
for (auto result : funcReturn.getOperands()) {
indent();
if (auto memType = result.getType().dyn_cast<MemRefType>()) {
emitValue(result);
for (auto &shape : memType.getShape())
os << "[" << shape << "]";
} else {
if (auto type = result.getType().dyn_cast<MemRefType>())
emitArrayDecl<MemRefType>(result, type);
else if (auto type = result.getType().dyn_cast<TensorType>())
emitArrayDecl<TensorType>(result, type);
else if (auto type = result.getType().dyn_cast<VectorType>())
emitArrayDecl<VectorType>(result, type);
else
// In Vivado HLS, pointer type indicates an output scalar value.
emitValue(result, /*isPtr=*/true);
}
if (resultIdx == func.getNumResults() - 1)
os << "\n";
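End to end, a hypothetical function with one memref<16xi32> argument and one scalar f32 result would now get a signature along these lines (the name foo and the formatting are approximate; the emitter prints one parameter per line):

    void foo(int v0[16], float *v1)

The scalar result becomes a pointer parameter per the Vivado HLS convention noted in the comment above, while memref/tensor/vector arguments and results become arrays.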


@ -15,6 +15,9 @@ func @test_standard(%val1: i32, %val2: memref<16xi32>, %fval1: f32, %fval2: f32)
%tensor1 = tensor_load %val2 : memref<16xi32>
tensor_store %tensor1, %val2 : memref<16xi32>
%tensor2 = constant dense<[[23, 0], [-23, 0]]> : tensor<2x2xi32>
%tensor3 = addi %tensor0, %tensor2 : tensor<2x2xi32>
%c0 = constant 0 : index
%dim = dim %tensor1, %c0 : tensor<16xi32>
%rank = rank %tensor0 : tensor<2x2xi32>