[EmitHLSCpp] function signature handler; nameTable mechanism; correctly emit addi example

This commit is contained in:
Hanchen Ye 2020-08-29 14:56:24 -05:00
parent 395b70052c
commit 2115585a94
2 changed files with 119 additions and 67 deletions

View File

@ -16,44 +16,9 @@
#include "EmitHLSCpp.h"
using namespace std;
using namespace mlir;
//===------------------------------------------------------------*- C++ -*-===//
// Utils
//===----------------------------------------------------------------------===//
static SmallString<8> getTypeString(Type type, Operation *op) {
SmallString<8> typeString("unknown");
switch (type.getKind()) {
// Handle float types.
case StandardTypes::F16:
typeString = "float";
break;
case StandardTypes::F32:
typeString = "double";
break;
// Handle integer types.
case StandardTypes::Index:
typeString = "int";
break;
case StandardTypes::Integer: {
auto intType = type.cast<IntegerType>();
typeString = "ap_";
if (intType.getSignedness() == IntegerType::SignednessSemantics::Unsigned)
typeString += "u";
typeString += StringRef("int<" + std::to_string(intType.getWidth()) + ">");
break;
}
default:
op->emitError("has unsupported type.");
break;
} // switch (type.getKind())
return typeString;
}
//===----------------------------------------------------------------------===//
// Some Base Classes
//
@ -155,6 +120,10 @@ public:
};
} // namespace
//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// ModuleEmitter Class Definition
//===----------------------------------------------------------------------===//
@ -170,17 +139,12 @@ public:
void emitModule(ModuleOp module);
private:
DenseMap<Value, SmallString<4>> nameTable;
DenseMap<Value, SmallString<8>> nameTable;
SmallString<4> getName(Value value) { return nameTable[value]; }
SmallString<4> addName(Value value) {
// Temporary naming rule.
SmallString<4> valueName("val");
valueName += StringRef(std::to_string(nameTable.size()));
SmallString<8> getName(Value val) { return nameTable[val]; }
SmallString<8> addName(Value val, bool isPtr);
nameTable[value] = valueName;
return valueName;
};
void emitValueDecl(Value val, bool isPtr);
void emitOperation(Operation *op);
void emitFunction(FuncOp func);
@ -198,6 +162,7 @@ public:
using HLSCppVisitorBase::visitOp;
bool visitOp(AddIOp op) { return emitter.emitBinaryExpr(op, "+"), true; }
bool visitOp(ReturnOp op) { return true; }
private:
ModuleEmitter &emitter;
@ -222,43 +187,126 @@ private:
// ModuleEmitter Class Implementation
//===----------------------------------------------------------------------===//
SmallString<8> ModuleEmitter::addName(Value val, bool isPtr = false) {
// Temporary naming rule.
SmallString<8> newName;
if (isPtr)
newName += "*";
newName += StringRef("val" + to_string(nameTable.size()));
auto valName = nameTable[val];
if (!valName.empty() && valName != newName)
return valName;
else {
nameTable[val] = newName;
return newName;
}
}
void ModuleEmitter::emitValueDecl(Value val, bool isPtr = false) {
// Value has been declared before.
if (!getName(val).empty()) {
os << getName(val);
return;
}
switch (val.getType().getKind()) {
// Handle float types.
case StandardTypes::F16:
os << "float ";
break;
case StandardTypes::F32:
os << "double ";
break;
// Handle integer types.
case StandardTypes::Index:
os << "int ";
break;
case StandardTypes::Integer: {
auto intType = val.getType().cast<IntegerType>();
os << "ap_";
if (intType.getSignedness() == IntegerType::SignednessSemantics::Unsigned)
os << "u";
os << "int<" << intType.getWidth() << "> ";
break;
}
default:
emitError(val.getDefiningOp(), "has unsupported type.");
break;
}
// Add the new value declaration to nameTable.
os << addName(val, isPtr);
return;
}
void ModuleEmitter::emitBinaryExpr(Operation *op, const char *syntax) {
indent();
emitValueDecl(op->getResult(0));
// Emit result type.
os << getTypeString(op->getResultTypes().front(), op) << " ";
// Emit result value name.
os << addName(op->getResult(0)) << " = ";
// Emit expression.
os << getName(op->getOperand(0));
// Emit expression. We are not folding sub-expressions for now.
os << " = " << getName(op->getOperand(0));
os << " " << syntax << " ";
os << getName(op->getOperand(1)) << ";\n";
}
void ModuleEmitter::emitOperation(Operation *op) {
ExprVisitor(*this).dispatchVisitor(op);
};
if (ExprVisitor(*this).dispatchVisitor(op))
return;
emitError(op, "can't be correctly emitted.");
}
void ModuleEmitter::emitFunction(FuncOp func) {
os << "void " << func.getName() << " (";
if (func.getBlocks().size() != 1)
emitError(func, "has more than one basic blocks.");
os << "void " << func.getName() << "(\n";
// TODO: handle function signature.
for (auto &arg : func.getArguments())
addName(arg);
os << ") {\n";
// Emit function signature.
addIndent();
// TODO: handle all operations.
if (func.getBlocks().size() != 1)
emitError(func, "has more than one blocks.");
// Handle input arguments.
unsigned argIdx = 0;
for (auto &arg : func.getArguments()) {
indent();
emitValueDecl(arg);
if (argIdx == func.getNumArguments() - 1 && func.getNumResults() == 0)
os << "\n";
else
os << ",\n";
argIdx += 1;
}
// Handle results.
if (auto funcReturn = dyn_cast<ReturnOp>(func.front().getTerminator())) {
unsigned resultIdx = 0;
for (auto result : funcReturn.getOperands()) {
indent();
emitValueDecl(result, /*isPtr=*/true);
if (resultIdx == func.getNumResults() - 1)
os << "\n";
else
os << ",\n";
resultIdx += 1;
}
} else {
emitError(func, "doesn't have return operation as terminator.");
}
reduceIndent();
os << ") {\n";
// Emit function body.
addIndent();
// Traverse all operations and emit them.
for (auto &op : func.front())
emitOperation(&op);
reduceIndent();
os << "}\n";
};
}
void ModuleEmitter::emitModule(ModuleOp module) {
os << R"XXX(

View File

@ -1,11 +1,15 @@
// RUN: hlsld-translate -emit-hlscpp %s | FileCheck %s
// CHECK: void test_standard () {
// CHECK: void test_standard(
// CHECK-NEXT: ap_int<32> val1,
// CHECK-NEXT: ap_int<32> val2,
// CHECK-NEXT: ap_int<32> *val3
// CHECK-NEXT: ) {
func @test_standard(%arg0: i32, %arg1: i32) -> (i32) {
// CHECK: ap_int<32> val2 = val0 + val1;
// CHECK: *val3 = val1 + val2;
%0 = addi %arg0, %arg1 : i32
return %0 : i32
// CHECK-NEXT: }
// CHECK: }
}