From 5188529c85947cb4025207b47063ec6599511f6d Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Fri, 30 Apr 2021 17:57:24 +0100 Subject: [PATCH 1/8] Added llvm submodule --- .gitmodules | 3 +++ llvm | 1 + 2 files changed, 4 insertions(+) create mode 100644 .gitmodules create mode 160000 llvm diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..7a38ca9 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "llvm"] + path = llvm + url = git@github.com:circt/llvm.git diff --git a/llvm b/llvm new file mode 160000 index 0000000..95bfd08 --- /dev/null +++ b/llvm @@ -0,0 +1 @@ +Subproject commit 95bfd0849f7fb8d0fe2c5d971ed97c219e1ccf72 From 2f9fcb8b4b740aed6c6d85b4f00f7765fde247ea Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Fri, 30 Apr 2021 18:34:36 +0100 Subject: [PATCH 2/8] updated doc --- README.md | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index a004907..1e64235 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,21 @@ This project aims to create a framework that ultimately converts an algorithm written in a high level language into an efficient hardware implementation. With multiple levels of intermediate representations (IRs), MLIR appears to be the ideal tool for exploring ways to optimize the eventual design at various levels of abstraction (e.g. various levels of parallelism). Our framework will be based on MLIR, it will incorporate a backend for high level synthesis (HLS) C/C++ code. However, the key contribution will be our parameterization and optimization of a tremendously large design space. ## Quick Start + +### 0. Download ScaleHLS and LLVM + +``` +$ git clone git@github.com:hanchenye/scalehls.git +$ cd scalehls +$ git submodule init +$ git submodule update +``` + ### 1. Install LLVM and MLIR -**IMPORTANT** This step assumes that you have cloned LLVM from (https://github.com/circt/llvm/tree/main) to `$LLVM_DIR` and checked out the `main` branch. To build LLVM and MLIR, run: +To build LLVM and MLIR, run: ```sh -$ mkdir $LLVM_DIR/build -$ cd $LLVM_DIR/build +$ mkdir $SCALEHLS_DIR/llvm/build +$ cd $SCALEHLS_DIR/llvm/build $ cmake -G Ninja ../llvm \ -DLLVM_ENABLE_PROJECTS="mlir" \ -DLLVM_TARGETS_TO_BUILD="X86;RISCV" \ @@ -23,8 +33,8 @@ This step assumes this repository is cloned to `$SCALEHLS_DIR`. To build and lau $ mkdir $SCALEHLS_DIR/build $ cd $SCALEHLS_DIR/build $ cmake -G Ninja .. \ - -DMLIR_DIR=$LLVM_DIR/build/lib/cmake/mlir \ - -DLLVM_DIR=$LLVM_DIR/build/lib/cmake/llvm \ + -DMLIR_DIR=$SCALEHLS_DIR/llvm/build/lib/cmake/mlir \ + -DLLVM_DIR=$SCALEHLS_DIR/llvm/build/lib/cmake/llvm \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DCMAKE_BUILD_TYPE=DEBUG $ ninja check-scalehls From 671b4a9991225b36e42e17774553bcdedb04e8d8 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Fri, 30 Apr 2021 13:21:20 -0500 Subject: [PATCH 3/8] fix bugs in new LLVM version --- include/scalehls/Dialect/HLSCpp/Attributes.td | 5 +++-- lib/Transforms/Loop/AffineLoopOrderOpt.cpp | 1 + tools/scalehls-opt/scalehls-opt.cpp | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/include/scalehls/Dialect/HLSCpp/Attributes.td b/include/scalehls/Dialect/HLSCpp/Attributes.td index cd3c332..b5c7812 100644 --- a/include/scalehls/Dialect/HLSCpp/Attributes.td +++ b/include/scalehls/Dialect/HLSCpp/Attributes.td @@ -7,8 +7,9 @@ #ifndef SCALEHLS_DIALECT_HLSCPP_ATTRIBUTES_TD #define SCALEHLS_DIALECT_HLSCPP_ATTRIBUTES_TD -class HLSCppAttr - : AttrDef; +class HLSCppAttr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef; def Resource : HLSCppAttr<"Resource"> { let summary = "Resource utilization information"; diff --git a/lib/Transforms/Loop/AffineLoopOrderOpt.cpp b/lib/Transforms/Loop/AffineLoopOrderOpt.cpp index 83c5da7..0b72c2d 100644 --- a/lib/Transforms/Loop/AffineLoopOrderOpt.cpp +++ b/lib/Transforms/Loop/AffineLoopOrderOpt.cpp @@ -9,6 +9,7 @@ #include "mlir/Transforms/LoopUtils.h" #include "scalehls/Transforms/Passes.h" #include "scalehls/Transforms/Utils.h" +#include "llvm/Support/Debug.h" #define DEBUG_TYPE "scalehls" diff --git a/tools/scalehls-opt/scalehls-opt.cpp b/tools/scalehls-opt/scalehls-opt.cpp index bcb5481..c2907f5 100644 --- a/tools/scalehls-opt/scalehls-opt.cpp +++ b/tools/scalehls-opt/scalehls-opt.cpp @@ -13,6 +13,6 @@ int main(int argc, char **argv) { mlir::scalehls::registerAllDialects(registry); mlir::scalehls::registerAllPasses(); - return mlir::failed( - mlir::MlirOptMain(argc, argv, "ScaleHLS Optimization Tool", registry)); + return mlir::failed(mlir::MlirOptMain( + argc, argv, "ScaleHLS Optimization Tool", registry, true)); } From 28c5594ce570535eb8d3bf94ef36e5e5f06325fd Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Fri, 30 Apr 2021 19:53:51 +0100 Subject: [PATCH 4/8] changed llvm branch to main --- README.md | 4 ++-- llvm | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 1e64235..21b9910 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,8 @@ This step assumes this repository is cloned to `$SCALEHLS_DIR`. To build and lau $ mkdir $SCALEHLS_DIR/build $ cd $SCALEHLS_DIR/build $ cmake -G Ninja .. \ - -DMLIR_DIR=$SCALEHLS_DIR/llvm/build/lib/cmake/mlir \ - -DLLVM_DIR=$SCALEHLS_DIR/llvm/build/lib/cmake/llvm \ + -DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \ + -DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DCMAKE_BUILD_TYPE=DEBUG $ ninja check-scalehls diff --git a/llvm b/llvm index 95bfd08..49745f8 160000 --- a/llvm +++ b/llvm @@ -1 +1 @@ -Subproject commit 95bfd0849f7fb8d0fe2c5d971ed97c219e1ccf72 +Subproject commit 49745f87e61014ac2a9e93bcad1225c55695b9b7 From fd48c4cc5c8b344f678f8260b24f7de8da4e5b51 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Fri, 30 Apr 2021 23:23:12 +0100 Subject: [PATCH 5/8] Added C front end --- CMakeLists.txt | 5 + README.md | 5 +- test/CMakeLists.txt | 1 + test/lit.cfg.py | 1 + tools/CMakeLists.txt | 1 + tools/scalehls-clang/CMakeLists.txt | 20 + tools/scalehls-clang/scalehls-clang.cpp | 1423 +++++++++++++++++++++++ 7 files changed, 1455 insertions(+), 1 deletion(-) create mode 100644 tools/scalehls-clang/CMakeLists.txt create mode 100644 tools/scalehls-clang/scalehls-clang.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 88f1070..bbdbc65 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,10 @@ find_package(MLIR REQUIRED CONFIG) message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") +message(STATUS "Using ClangConfig.cmake in: ${CLANG_DIR}") + +set(Clang_DIR ${CLANG_DIR}) +find_package(Clang REQUIRED) set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) @@ -50,6 +54,7 @@ include(AddLLVM) include(AddMLIR) include(HandleLLVMOptions) +include_directories(${CLANG_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${PROJECT_SOURCE_DIR}/include) diff --git a/README.md b/README.md index 21b9910..bc6353b 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ To build LLVM and MLIR, run: $ mkdir $SCALEHLS_DIR/llvm/build $ cd $SCALEHLS_DIR/llvm/build $ cmake -G Ninja ../llvm \ - -DLLVM_ENABLE_PROJECTS="mlir" \ + -DLLVM_ENABLE_PROJECTS="mlir;llvm;clang;clang-extra-tools" \ -DLLVM_TARGETS_TO_BUILD="X86;RISCV" \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DCMAKE_BUILD_TYPE=DEBUG @@ -35,6 +35,9 @@ $ cd $SCALEHLS_DIR/build $ cmake -G Ninja .. \ -DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \ -DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \ + -DCLANG_DIR=$PWD/../llvm/build/lib/cmake/clang \ + -DCMAKE_C_COMPILER=$PWD/../llvm/build/bin/clang \ + -DCMAKE_CXX_COMPILER=$PWD/../llvm/build/bin/clang++ \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DCMAKE_BUILD_TYPE=DEBUG $ ninja check-scalehls diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 584b939..9597402 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -7,6 +7,7 @@ configure_lit_site_cfg( set(SCALEHLS_TEST_DEPENDS FileCheck count not + scalehls-clang scalehls-opt scalehls-translate benchmark-gen diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 3f8004f..2bd2fd8 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -53,6 +53,7 @@ llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) tool_dirs = [config.scalehls_tools_dir, config.mlir_tools_dir, config.llvm_tools_dir] tools = [ + 'scalehls-clang', 'scalehls-opt', 'scalehls-translate', 'benchmark-gen' diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 711e691..600b52b 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(scalehls-clang) add_subdirectory(scalehls-opt) add_subdirectory(scalehls-translate) add_subdirectory(benchmark-gen) diff --git a/tools/scalehls-clang/CMakeLists.txt b/tools/scalehls-clang/CMakeLists.txt new file mode 100644 index 0000000..634bbb6 --- /dev/null +++ b/tools/scalehls-clang/CMakeLists.txt @@ -0,0 +1,20 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + +#include_directories(../../llvm/clang/include/) +#include_directories(../../llvm/build/tools/clang/include) + +#find_package(Clang REQUIRED) + +add_llvm_tool(scalehls-clang + scalehls-clang.cpp + ) + +llvm_update_compile_flags(scalehls-clang) + +target_link_libraries(scalehls-clang + PRIVATE + clangFrontend + clangTooling + clangBasic + ) diff --git a/tools/scalehls-clang/scalehls-clang.cpp b/tools/scalehls-clang/scalehls-clang.cpp new file mode 100644 index 0000000..9b7d85a --- /dev/null +++ b/tools/scalehls-clang/scalehls-clang.cpp @@ -0,0 +1,1423 @@ +//===----------------------------------------------------------------------===// +// +// Copyright 2020-2021 The ScaleHLS Authors. +// +//===----------------------------------------------------------------------===// + +#include "clang/AST/ASTConsumer.h" +#include "clang/AST/Decl.h" +#include "clang/AST/Expr.h" +#include "clang/AST/Mangle.h" +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/AST/Stmt.h" +#include "clang/Frontend/CompilerInstance.h" +#include "clang/Frontend/FrontendAction.h" +#include "clang/Tooling/CommonOptionsParser.h" +#include "clang/Tooling/Tooling.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" + +#include +#include +#include +#include +#include +#include +#include + +using namespace clang; + +typedef std::vector int_vec; +typedef std::vector str_vec; + +// ------------------------------------------- +// This tool emits the MLIR code from C code by traversing the Clang AST. +// ------------------------------------------- + +// Default mode configuration +// Disable full flow check - enabling debug mode will lower the MLIR through +// SCF, standard, handshake till HPX Ops. +bool debug = false; +// Output file name +std::string output = "output.mlir"; + +// Output file stream +std::ofstream mlirOut; +// Number of variables in the MLIR code including intermediate variables +int varCnt = 0; +// Number of intermediate variables +int unnameVarCnt = 0; +// Unary variables to be updated after statement +int_vec unaryUpdateList; + +// Print error messages +auto error = [](auto str) { + std::cout << "\033[0;31mError: " << str << "\033[0m" << std::endl; + assert(0); +}; + +// Print highlight messages, like warning +auto message = [](auto str) { + std::cout << "\033[0;33mWarning: " << str << "\033[0m" << std::endl; +}; + +// Translation unit +class MLIRGeneratorVisitor : public RecursiveASTVisitor { +public: + explicit MLIRGeneratorVisitor(ASTContext *Context) : Context(Context) {} + + virtual bool VisitFunctionDecl(FunctionDecl *func) { + // Currently ignore the function prototype + if (func->hasBody()) + funcGen(func); + return true; + } + +private: + struct variable { + int ID; + // Original variable name in string + std::string name; + // Variable Declaration + ValueDecl *VarDecl; + QualType type; + }; + typedef std::vector var_vec; + + ASTContext *Context; + // Variable list + var_vec vars; + + bool isIndex(QualType type) { return type == QualType(); } + QualType getIndexType(void) { return QualType(); } + + // Extract the element type for an array + QualType getArrayElementType(const ConstantArrayType *Type) { + auto arrayType = dyn_cast(Type); + auto elementType = arrayType; + // Extract all the dimensions + do { + elementType = arrayType; + arrayType = dyn_cast( + arrayType->getArrayElementTypeNoTypeQual()); + } while (arrayType); + return elementType->getArrayElementTypeNoTypeQual() + ->getLocallyUnqualifiedSingleStepDesugaredType(); + } + + // Extract a list of indices for an array access + void getArrayIdxExpr(ArraySubscriptExpr *arrayExpr, int_vec *idxList, + std::string *exprBuff) { + auto expr = arrayExpr; + do { + auto idxExpr = expr->getIdx(); + assert(idxExpr); + auto idx = castToIndex(exprGen(idxExpr, -1, exprBuff), exprBuff); + idxList->push_back(idx); + expr = dyn_cast( + dyn_cast(expr->getBase())->getSubExpr()); + } while (expr); + } + + // Extract the variable ID of the array label + int getArrayLabel(ArraySubscriptExpr *expr) { + auto subExpr = dyn_cast(expr->getBase())->getSubExpr(); + while (auto arrayExpr = dyn_cast(subExpr)) + subExpr = dyn_cast(arrayExpr->getBase())->getSubExpr(); + assert(subExpr); + auto declRefExpr = dyn_cast(subExpr); + assert(declRefExpr && "Unsupported array label format"); + auto label = dyn_cast(declRefExpr->getDecl()); + assert(label && "Unsupported array label format"); + return getVarID(label); + } + + // Print the variable type to equivalent MLIR representation in string + std::string typePrint(QualType Type) { + // Customised type - index + if (isIndex(Type)) + return "index"; + else if (Type->isBooleanType()) + return "i1"; + else if (Type->isIntegerType()) + return "i32"; // i" + std::to_string(Context->getTypeSize(Type)); + else if (Type->isFloatingType()) + return "f64"; + else if (Type->isPointerType()) + return typePrint(Type->getPointeeType()); + else if (Type->isConstantArrayType()) { + std::string arrayDimension = "memref<"; + auto arrayType = dyn_cast(Type); + auto elementType = getArrayElementType(arrayType); + // Extract all the dimensions + do { + arrayDimension += + std::to_string(arrayType->getSize().getSExtValue()) + "x"; + arrayType = dyn_cast( + arrayType->getArrayElementTypeNoTypeQual()); + } while (arrayType); + arrayDimension += typePrint(elementType) + ">"; + return arrayDimension; + } else if (!Type->isVoidType()) + error("Undefined type found in typePrint: " + + Type->getPointeeType().getAsString()); + return ""; + } + std::string typePrint(int i) { return typePrint(getVarType(i)); } + + // Return the name of the variable as string + std::string getVarName(int i) { + if (i < 0 || i >= varCnt) + error("getVarName: Invalid var ID: " + std::to_string(i)); + return vars[i]->name; + } + + // Return the ID of the variable + int getVarID(ValueDecl *decl) { + assert(decl && "Unknown variable referenced"); + for (auto &var : vars) { + if (var->VarDecl && var->VarDecl == decl) + return var->ID; + } + error("Undefined variable: " + decl->getNameAsString()); + return -1; + } + + // Check whether a variable has been declared at the check point + bool hasBeenDeclared(ValueDecl *decl) { + for (auto &var : vars) { + if (var->VarDecl == decl) + return true; + } + return false; + } + + // Return the type of the variable + QualType getVarType(int i) { + if (i < 0 || i >= varCnt) + error("getVarType: Invalid var ID: " + std::to_string(i)); + return vars[i]->type; + } + + // Set the name of a variable to the given string + void setVarName(int i, std::string name) { + if (i < 0 || i >= varCnt) + error("Invalid var ID: " + std::to_string(i)); + vars[i]->name = name; + } + + // Increment variable count for a named variable + std::string incrName(std::string name) { + int strLoc = name.rfind("."); + return (strLoc != -1) + ? name.substr(0, strLoc + 1) + + std::to_string(1 + std::stoi(name.substr(strLoc + 1))) + : name + ".0"; + } + // Create a copy of the variable for the updated value + std::string incrVarName(int i) { + assert(i >= 0 && i < varCnt); + auto name = vars[i]->name; + if (vars[i]->VarDecl == nullptr) + message("May have an unnessary operation on " + name); + auto newName = incrName(name); + setVarName(i, newName); + return newName; + } + std::string incrVarName(ValueDecl *decl) { + return incrVarName(getVarID(decl)); + } + + // Add a Variable to the variable list + int addVar(ValueDecl *decl, QualType type) { + variable *var = new variable; + var->ID = varCnt; + var->VarDecl = decl; + if (decl == nullptr) { + var->name = "%" + std::to_string(unnameVarCnt); + unnameVarCnt++; + } else + var->name = "%" + decl->getNameAsString(); + var->type = type; + vars.push_back(var); + varCnt++; + return var->ID; + } + int addVar(ValueDecl *decl) { + assert(decl != nullptr); + return addVar(decl, decl->getType()); + } + + // Generate a constant value + int constIntGen(int value, QualType ty, std::string *buff) { + assert((isIndex(ty) || ty->isIntegerType()) && + "Generating a const int as a non-int type"); + auto tmp = addVar(nullptr, ty); + *buff += getVarName(tmp) + " = constant " + std::to_string(value) + " : " + + typePrint(ty) + "\n"; + return tmp; + } + + // Print the binary operator in MLIR + std::string binaryOpPrint(BinaryOperator *BinaryOp, QualType typeIn) { + // typeIn == nullptr is Index + std::string isInt = (isIndex(typeIn) || typeIn->isIntegerType()) + ? "i" + : ((typeIn->isFloatingType()) ? "f" : "x"); + assert(isInt != "x"); + std::string sign = (isIndex(typeIn) || typeIn->isSignedIntegerType()) + ? "s" + : (typeIn->isUnsignedIntegerType() ? "u" : "x"); + std::string signLong = + (isIndex(typeIn) || typeIn->isSignedIntegerType()) + ? "_signed" + : (typeIn->isUnsignedIntegerType() ? "_unsigned" : "x"); + + switch ( + BinaryOp + ->getOpcode()) { // jc: + // https://github.com/mull-project/mull/blob/3c1ad3b8ca428b816a3e3e5bd36568afd4090260/lib/JunkDetection/CXX/CXXJunkDetector.cpp + case clang::BinaryOperator::Opcode::BO_Add: + return "add" + isInt; + break; + case clang::BinaryOperator::Opcode::BO_AddAssign: + return "add" + isInt; + break; + case clang::BinaryOperator::Opcode::BO_Rem: + return (typeIn->isIntegerType()) ? "remi" + signLong : "remf"; + break; + case clang::BinaryOperator::Opcode::BO_Sub: + return "sub" + isInt; + break; + case clang::BinaryOperator::Opcode::BO_SubAssign: + return "sub" + isInt; + break; + case clang::BinaryOperator::Opcode::BO_Mul: + return "mul" + isInt; + break; + case clang::BinaryOperator::Opcode::BO_MulAssign: + return "mul" + isInt; + break; + case clang::BinaryOperator::Opcode::BO_Div: + return (typeIn->isIntegerType()) ? "divi" + signLong : "divf"; + break; + case clang::BinaryOperator::Opcode::BO_DivAssign: + return (typeIn->isIntegerType()) ? "divi" + signLong : "divf"; + break; + case clang::BinaryOperator::Opcode::BO_Or: + assert(typeIn->isIntegerType()); + return "or"; + break; + case clang::BinaryOperator::Opcode::BO_LOr: + assert(typeIn->isIntegerType()); + return "or"; + break; + case clang::BinaryOperator::Opcode::BO_And: + assert(typeIn->isIntegerType()); + return "and"; + break; + case clang::BinaryOperator::Opcode::BO_LAnd: + assert(typeIn->isIntegerType()); + return "and"; + break; + case clang::BinaryOperator::Opcode::BO_LT: + return "cmp" + isInt + " \"" + sign + "lt\","; + break; + case clang::BinaryOperator::Opcode::BO_LE: + return "cmp" + isInt + " \"" + sign + "le\","; + break; + case clang::BinaryOperator::Opcode::BO_GT: + return "cmp" + isInt + " \"" + sign + "gt\","; + break; + case clang::BinaryOperator::Opcode::BO_GE: + return "cmp" + isInt + " \"" + sign + "ge\","; + break; + case clang::BinaryOperator::Opcode::BO_EQ: + return "cmp" + isInt + " \"eq\","; + break; + case clang::BinaryOperator::Opcode::BO_NE: + return "cmp" + isInt + " \"ne\","; + break; + default: + error("Undefined binaryOperator in binaryOpPrint: " + + std::string(BinaryOp->getOpcodeStr())); + } + return ""; + } + + // Generate the binary expression from Clang AST + int binaryOperatorGen(clang::Stmt *Stmt, int resultVar, + std::string *exprBuff) { + auto BinaryOp = dyn_cast(Stmt); + auto result = -1; + auto operand = Stmt->child_begin(); + + if (BinaryOp->getOpcodeStr() == "=") { + if (operand->getStmtClass() == Stmt::StmtClass::DeclRefExprClass) { + // Assign to a register + auto a = dyn_cast(*operand)->getDecl(); + result = + exprGen(*(++operand), getVarID(dyn_cast(a)), exprBuff); + } else if (operand->getStmtClass() == + Stmt::StmtClass::ArraySubscriptExprClass) { + auto arrayElement = dyn_cast(*operand); + result = exprGen(*(++operand), -1, exprBuff); + assert(arrayElement); + assert(result != -1); + storeGen(result, arrayElement, exprBuff); + } else + error("Unknown assignment in binaryOperatorGen"); + } else { // An intermediate binary expression + int var0, var1; + var0 = exprGen(*operand, -1, exprBuff); + var1 = exprGen(*(++operand), -1, exprBuff); + auto typeIn = getVarType(var0); + if (typeIn != getVarType(var1)) + typeIn = typeCast(&var0, &var1, exprBuff); + auto op = binaryOpPrint(BinaryOp, typeIn); + if (resultVar != -1) { + auto buff = " = " + op + " " + getVarName(var0) + ", " + + getVarName(var1) + " : " + typePrint(typeIn) + "\n"; + assert(getVarType(resultVar) == typeIn && + "Cast for binary op not supported"); + *exprBuff += incrVarName(resultVar) + buff; + result = resultVar; + } else { + result = addVar(nullptr, BinaryOp->getType()); + *exprBuff += getVarName(result) + " = " + op + " " + getVarName(var0) + + ", " + getVarName(var1) + " : " + typePrint(typeIn) + "\n"; + } + assert(!BinaryOp->isComparisonOp() || + getVarType(result)->isBooleanType()); + } + return result; + } + + // Return the precision of the type, which can be used to decide which + // variable needs to be cast + int getTypePrecision(QualType Type) { + if (isIndex(Type)) + return 0; // least precision + else if (Type->isIntegerType()) + return 1; + else if (Type->isFloatingType()) + return 2; + else + error("Undefined type found in getTypePrecision: " + + Type->getPointeeType().getAsString()); + return -1; + } + + // Cast a variable from a type to another + QualType typeCast(int *var0, int *var1, std::string *exprBuff) { + auto v0 = *var0; + auto v1 = *var1; + auto ty0 = getVarType(v0); + auto ty1 = getVarType(v1); + auto p0 = getTypePrecision(ty0); + auto p1 = getTypePrecision(ty1); + + if (p0 < p1) { + *var0 = castTo(v0, p0 * 10 + p1, ty1, exprBuff); + return ty1; + } else { + *var1 = castTo(v1, p1 * 10 + p0, ty0, exprBuff); + return ty0; + } + } + + // Cast type to the high-precision type based on the given mode + int castTo(int var, int mode, QualType ty, std::string *exprBuff) { + int res; + switch (mode) { + case 1: // index2int + res = addVar(nullptr, ty); + *exprBuff += getVarName(res) + " = index_cast " + getVarName(var) + + " : index to i32\n"; + return res; + break; + case 10: // int2index + res = addVar(nullptr, ty); + *exprBuff += getVarName(res) + " = index_cast " + getVarName(var) + + " : i32 to index\n"; + return res; + break; + case 2: // index2float + error("Undefined type conversion in castTo."); + break; + case 12: // int2float + error("Undefined type conversion in castTo."); + break; + default: + error("Undefined type conversion in castTo."); + } + return -1; + } + + // Generate storeOp in MLIR + std::string addStore(int src, int dst, int_vec *index) { + assert(index->size() > 0 && "No index for store op"); + + std::string buffer; + auto arrayType = getVarType(dst); + std::string buff, args; + for (auto &i : *index) + args += getVarName(castToIndex(i, &buff)) + ", "; + args.pop_back(); + args.pop_back(); + args += "] : " + typePrint(arrayType) + "\n"; + buff += "store " + getVarName(src) + ", " + getVarName(dst) + "[" + args; + return buff; + } + + // Generate loadOp in MLIR + int addLoad(int array, int resultVar, std::string *buff, int_vec *index) { + assert(index->size() > 0 && "No index for load op"); + auto type = getVarType(array); + auto arrayType = dyn_cast(type); + assert(arrayType); + auto elementType = getArrayElementType(arrayType); + int res; + std::string resName; + if (resultVar == -1) { + res = addVar(nullptr, elementType); + resName = getVarName(res); + } else { + res = resultVar; + resName = incrVarName(res); + } + std::string args = resName + " = load " + getVarName(array) + "["; + for (auto &i : *index) + args += getVarName(castToIndex(i, buff)) + ", "; + args.pop_back(); + args.pop_back(); + args += "] : " + typePrint(type) + "\n"; + *buff += args; + return res; + } + + // Generate store op + void storeGen(int resultVar, ArraySubscriptExpr *expr, + std::string *exprBuff) { + int_vec idxList; + getArrayIdxExpr(expr, &idxList, exprBuff); + auto array = getArrayLabel(expr); + *exprBuff += addStore(resultVar, array, &idxList); + } + + // Generate load op + int loadGen(ArraySubscriptExpr *expr, int resultVar, std::string *exprBuff) { + int_vec idxList; + getArrayIdxExpr(expr, &idxList, exprBuff); + auto array = getArrayLabel(expr); + int res = addLoad(array, resultVar, exprBuff, &idxList); + return res; + } + + // Generate unary expression + int unaryExprGen(clang::Stmt *Stmt, int resultVar, std::string *exprBuff) { + UnaryOperator *unaryOp = dyn_cast(Stmt); + auto res = -1; + auto operand = unaryOp->getSubExpr(); + auto var = exprGen(operand, -1, exprBuff); + auto typeIn = getVarType(var); + std::string buff; + std::string isInt = (typeIn->isIntegerType()) + ? "i" + : ((typeIn->isFloatingType()) ? "f" : "x"); + + switch (unaryOp->getOpcode()) { + case clang::UnaryOperator::Opcode::UO_PreInc: + assert(typeIn->isIntegerType() && + "unary op \"++\" not supported for floating types"); + buff = " = addi " + getVarName(var) + ", " + + getVarName(constIntGen(1, typeIn, exprBuff)) + " : " + + typePrint(typeIn) + "\n"; + *exprBuff += incrVarName(var) + buff; + if (resultVar == -1) + res = var; + else { + res = addVar(nullptr, typeIn); + // Directly assignment is not legal in mlir. Instead use +0 here + *exprBuff += getVarName(res) + " = addi " + getVarName(var) + ", " + + getVarName(constIntGen(0, typeIn, exprBuff)) + " : " + + typePrint(typeIn) + "\n"; + } + break; + case clang::UnaryOperator::Opcode::UO_PostInc: + assert(typeIn->isIntegerType() && + "unary op \"++\" not supported for floating types"); + *exprBuff += incrName(getVarName(var)) + " = addi " + getVarName(var) + + ", " + getVarName(constIntGen(1, typeIn, exprBuff)) + " : " + + typePrint(typeIn) + "\n"; + assert(std::find(unaryUpdateList.begin(), unaryUpdateList.end(), var) == + unaryUpdateList.end() && + "Multiple unary ops for one variable in a single statement is not " + "supported"); + unaryUpdateList.push_back(var); + res = var; + break; + case clang::UnaryOperator::Opcode::UO_PreDec: + assert(typeIn->isIntegerType() && + "unary op \"--\" not supported for floating types"); + buff = " = subi " + getVarName(var) + ", " + + getVarName(constIntGen(1, typeIn, exprBuff)) + " : " + + typePrint(typeIn) + "\n"; + *exprBuff += incrVarName(var) + buff; + if (resultVar == -1) + res = var; + else { + res = addVar(nullptr, typeIn); + // Directly assignment is not legal in mlir. Instead use +0 here + *exprBuff += getVarName(res) + " = addi " + getVarName(var) + ", " + + getVarName(constIntGen(0, typeIn, exprBuff)) + " : " + + typePrint(typeIn) + "\n"; + } + error("Unsupported UnaryOperator in unaryExprGen: " + + std::string(unaryOp->getOpcodeStr(unaryOp->getOpcode()))); + break; + case clang::UnaryOperator::Opcode::UO_PostDec: + assert(typeIn->isIntegerType() && + "unary op \"--\" not supported for floating types"); + *exprBuff += incrName(getVarName(var)) + " = subi " + getVarName(var) + + ", " + getVarName(constIntGen(1, typeIn, exprBuff)) + " : " + + typePrint(typeIn) + "\n"; + assert(std::find(unaryUpdateList.begin(), unaryUpdateList.end(), var) == + unaryUpdateList.end() && + "Multiple unary ops for one variable in a single statement is not " + "supported"); + unaryUpdateList.push_back(var); + res = var; + break; + case clang::UnaryOperator::Opcode::UO_Not: + error("Unsupported UnaryOperator in unaryExprGen: " + + std::string(unaryOp->getOpcodeStr(unaryOp->getOpcode()))); + break; + case clang::UnaryOperator::Opcode::UO_LNot: + error("Unsupported UnaryOperator in unaryExprGen: " + + std::string(unaryOp->getOpcodeStr(unaryOp->getOpcode()))); + break; + case clang::UnaryOperator::Opcode::UO_Minus: + if (resultVar == -1) { + res = addVar(nullptr, typeIn); + *exprBuff += getVarName(res) + " = sub" + isInt + " " + + getVarName(constIntGen(0, typeIn, exprBuff)) + ", " + + getVarName(var) + " : " + typePrint(typeIn) + "\n"; + } else { + res = resultVar; + *exprBuff += incrVarName(res) + " = sub" + isInt + " " + + getVarName(constIntGen(0, typeIn, exprBuff)) + ", " + + getVarName(var) + " : " + typePrint(typeIn) + "\n"; + } + break; + default: + error("Undefined UnaryOperator in unaryExprGen: " + + std::string(unaryOp->getOpcodeStr(unaryOp->getOpcode()))); + } + return res; + } + + // Cast an integer to index + int castToIndex(int i, std::string *exprBuff) { + if (!isIndex(getVarType(i))) { + auto res = addVar(nullptr, getIndexType()); + *exprBuff += getVarName(res) + " = index_cast " + getVarName(i) + " : " + + typePrint(i) + " to index\n"; + return res; + } else + return i; + } + + // Generate integer/floating literal expression + int literalExprGen(Expr *expr, int result, std::string *exprBuff) { + auto intLiteral = dyn_cast(expr); + auto floatLiteral = dyn_cast(expr); + + llvm::SmallString<16> value; + if (intLiteral) + value = intLiteral->getValue().toString(10, true); + else if (floatLiteral) + floatLiteral->getValue().toString(value); + else + error("Unsupported literal type"); + + if (result == -1) { + auto rhsType = expr->getType(); + auto varID = addVar(nullptr, rhsType); + *exprBuff += getVarName(varID) + " = constant " + std::string(value) + + " : " + typePrint(rhsType) + "\n"; + return varID; + } else { + *exprBuff += incrVarName(result) + " = constant " + std::string(value) + + " : " + typePrint(result) + "\n"; + return result; + } + } + + int findCastMode(QualType dst, QualType src) { + auto dstType = typePrint(dst); + auto srcType = typePrint(src); + if (srcType == "i32" && dstType == "index") // int2index + return 10; + else if (srcType == "index" && dstType == "i32") // index2int + return 1; + else + error("Unsupported conversion : " + srcType + " -> " + dstType); + return -1; + } + + // Generate Implicit Cast Expression + int implicitCastExprGen(ImplicitCastExpr *expr, int resultVar, + std::string *exprBuff) { + auto result = exprGen(expr->getSubExpr(), resultVar, exprBuff); + auto dstType = expr->getType(); + auto srcType = getVarType(result); + if (dstType == srcType) + return result; + else { + auto mode = findCastMode(dstType, srcType); + return castTo(result, mode, dstType, exprBuff); + } + } + + // Generate conditional operator + // TODO: use SCF.if? + int conditionalOperatorGen(ConditionalOperator *expr, int resultVar, + std::string *exprBuff) { + auto condExpr = exprGen(expr->getCond(), -1, exprBuff); + auto trueExpr = exprGen(expr->getTrueExpr(), -1, exprBuff); + auto falseExpr = exprGen(expr->getFalseExpr(), -1, exprBuff); + auto ty = getVarType(trueExpr); + assert(getVarType(falseExpr) == ty); + assert(getVarType(condExpr)->isBooleanType()); + if (resultVar == -1) { + auto result = addVar(nullptr, ty); + *exprBuff += getVarName(result) + " = select " + getVarName(condExpr) + + ", " + getVarName(trueExpr) + ", " + getVarName(falseExpr) + + " : " + typePrint(ty); + return result; + } else { + *exprBuff += incrVarName(resultVar) + " = select " + + getVarName(condExpr) + ", " + getVarName(trueExpr) + ", " + + getVarName(falseExpr) + " : " + typePrint(ty); + return resultVar; + } + } + + // Generate a generic expression + int exprGen(clang::Stmt *Stmt, int resultVar, std::string *exprBuff) { + if (debug) + std::cout << "Found stmt: " << Stmt->getStmtClassName() << "\n"; + auto Type = Stmt->getStmtClass(); + auto result = -1; + switch (Type) { + case Stmt::StmtClass::ImplicitCastExprClass: + result = implicitCastExprGen(dyn_cast(Stmt), resultVar, + exprBuff); + break; + case Stmt::StmtClass::ArraySubscriptExprClass: + result = loadGen(dyn_cast(Stmt), resultVar, exprBuff); + break; + case Stmt::StmtClass::BinaryOperatorClass: + result = binaryOperatorGen(Stmt, resultVar, exprBuff); + break; + case Stmt::StmtClass::ParenExprClass: + result = + exprGen(dyn_cast(Stmt)->getSubExpr(), resultVar, exprBuff); + break; + case Stmt::StmtClass::UnaryOperatorClass: + result = unaryExprGen(Stmt, resultVar, exprBuff); + break; + case Stmt::StmtClass::DeclRefExprClass: + if (resultVar == -1) + result = + getVarID(dyn_cast(dyn_cast(Stmt)->getDecl())); + else { + auto var = + getVarID(dyn_cast(dyn_cast(Stmt)->getDecl())); + auto ty = getVarType(var); + std::string isInt = (isIndex(ty) || ty->isIntegerType()) + ? "i" + : ((ty->isFloatingType()) ? "f" : "x"); + assert(isInt != "x"); + *exprBuff += incrVarName(resultVar) + " = add" + isInt + " " + + getVarName(constIntGen(0, ty, exprBuff)) + ", " + + getVarName(var) + " : " + typePrint(ty) + "\n"; + result = resultVar; + } + break; + case Stmt::StmtClass::IntegerLiteralClass: + result = literalExprGen(dyn_cast(Stmt), resultVar, exprBuff); + break; + case Stmt::StmtClass::FloatingLiteralClass: + result = literalExprGen(dyn_cast(Stmt), resultVar, exprBuff); + break; + case Stmt::StmtClass::CStyleCastExprClass: + result = exprGen(dyn_cast(Stmt)->getSubExpr(), resultVar, + exprBuff); + break; + case Stmt::StmtClass::ConditionalOperatorClass: + result = conditionalOperatorGen(dyn_cast(Stmt), + resultVar, exprBuff); + break; + default: + error("Undefined expression in exprGen: " + + std::string(Stmt->getStmtClassName())); + } + if (result == -1) + error("Translation failed for " + std::string(Stmt->getStmtClassName())); + return result; + } + + // Extract the variable list for the yield op + // Stmt - body of the statement + // yields - list of variables to yeild + void getYieldVars(clang::Stmt *Stmt, int_vec *yields) { + if (Stmt) { + if (auto forStmt = dyn_cast(Stmt)) { + getYieldVars(forStmt->getInit(), yields); + getYieldVars(forStmt->getCond(), yields); + getYieldVars(forStmt->getInc(), yields); + getYieldVars(forStmt->getBody(), yields); + } else if (auto ifStmt = dyn_cast(Stmt)) { + getYieldVars(ifStmt->getCond(), yields); + getYieldVars(ifStmt->getThen(), yields); + getYieldVars(ifStmt->getElse(), yields); + } else { + for (auto st : Stmt->children()) { + if (st->getStmtClass() == Stmt::StmtClass::DeclRefExprClass) { + // If a pre-defined variable is referred + auto var = dyn_cast(dyn_cast(st)->getDecl()); + if (!isa(Stmt) && hasBeenDeclared(var)) { + // If it is a store, then add to the yield list + auto varID = getVarID(var); + if (std::find(yields->begin(), yields->end(), varID) == + yields->end()) + yields->push_back(varID); + } + } else + getYieldVars(st, yields); + } + } + } + } + + // Generate if statements + void ifGen(IfStmt *Stmt, std::string *exprBuff) { + Expr *cond = Stmt->getCond(); + std::string condition = getVarName(exprGen( + &*cond, -1, exprBuff)); // JC: conditional variable not supported now + // yield + int_vec yieldsTrue, yieldsElse, yields; + clang::Stmt *scfTrue = Stmt->getThen(); + if (scfTrue) + getYieldVars(scfTrue, &yieldsTrue); + clang::Stmt *scfFalse = Stmt->getElse(); + if (scfFalse) + getYieldVars(scfFalse, &yieldsElse); + + bool toYield = (yieldsTrue.size() + yieldsElse.size() > 0); + str_vec yieldsVars; + std::string buffer; + std::string yieldTypes; + if (toYield) { + // merge two yield lists + yields = yieldsTrue; + for (auto y : yieldsElse) { + if (std::find(yields.begin(), yields.end(), y) == yields.end()) + yields.push_back(y); + } + // store var names in the parent region + for (auto y : yields) + yieldsVars.push_back(getVarName(y)); + // skip result vars to print if header + buffer += " = scf.if " + condition + " -> ("; + for (auto y : yields) + yieldTypes += typePrint(y) + ", "; + yieldTypes.pop_back(); + yieldTypes.pop_back(); + buffer += yieldTypes + ") {\n"; + } else + buffer += "scf.if " + condition + " {\n"; + if (scfTrue) + stmtGen(scfTrue, &buffer); + if (toYield) { // true branch + buffer += "scf.yield "; + for (auto y : yields) + buffer += getVarName(y) + ", "; + buffer.pop_back(); + buffer.pop_back(); + buffer += " : " + yieldTypes + "\n"; + for (unsigned long i = 0; i < yields.size(); i++) + setVarName(yields[i], yieldsVars[i]); + } + buffer += "} else {\n"; + if (scfFalse) + stmtGen(scfFalse, &buffer); + if (toYield) { // false branch + buffer += "scf.yield "; + for (auto y : yields) + buffer += getVarName(y) + ", "; + buffer.pop_back(); + buffer.pop_back(); + buffer += " : " + yieldTypes + "\n"; + for (unsigned long i = 0; i < yields.size(); i++) + setVarName(yields[i], yieldsVars[i]); + } + buffer += "}\n"; + std::string resultVars; + if (toYield) { // now print the result variables + for (auto y : yields) + resultVars += incrVarName(y) + ", "; + resultVars.pop_back(); + resultVars.pop_back(); + } + *exprBuff += resultVars + buffer; + } + + // Return whether the loop has jump + bool hasJump(ForStmt *Stmt) { + clang::Stmt *loopBody = Stmt->getBody(); + return hasGotoOrContinueOrBreakOrReturn(loopBody); + } + + // Return whether the loop has jump + bool hasGotoOrContinueOrBreakOrReturn(Stmt *Stmt) { + if (Stmt) + return 0; + AsmStmt::StmtClass Type = Stmt->getStmtClass(); + ForStmt *forStmt; + IfStmt *ifStmt; + int s = 0; + switch ( + Type) { // JC: class ref at + // https://github.com/silent-silence/CodeRefactor/blob/56462375195d26d0620cc454cb56f72abab760bb/AST/ASTContext.h + case Stmt::StmtClass::GotoStmtClass: + return true; + case Stmt::StmtClass::IndirectGotoStmtClass: + return true; + case Stmt::StmtClass::ContinueStmtClass: + return true; + case Stmt::StmtClass::BreakStmtClass: + return true; + case Stmt::StmtClass::ReturnStmtClass: + return true; + case Stmt::StmtClass::ForStmtClass: + forStmt = dyn_cast(Stmt); + s += hasGotoOrContinueOrBreakOrReturn(forStmt->getInit()); + s += hasGotoOrContinueOrBreakOrReturn(forStmt->getCond()); + s += hasGotoOrContinueOrBreakOrReturn(forStmt->getInc()); + s += hasGotoOrContinueOrBreakOrReturn(forStmt->getBody()); + break; + case Stmt::StmtClass::IfStmtClass: + ifStmt = dyn_cast(Stmt); + s += hasGotoOrContinueOrBreakOrReturn(ifStmt->getCond()); + s += hasGotoOrContinueOrBreakOrReturn(ifStmt->getThen()); + s += hasGotoOrContinueOrBreakOrReturn(ifStmt->getElse()); + break; + default: + for (auto &st : Stmt->children()) + s += hasGotoOrContinueOrBreakOrReturn(st); + } + return (s > 0); + } + + // Check whether a loop can be transformed into a SCF.for. If it can, generate + // SCF.For op + bool trySCFForGen(ForStmt *Stmt, std::string *exprBuff) { + + // Init check - check if the loop iterator is declared within the loop + auto init = Stmt->getInit(); + auto initDecl = dyn_cast(init); + if (!init || !initDecl->isSingleDecl()) + return false; + auto indDecl = dyn_cast(initDecl->getSingleDecl()); + assert(indDecl && indDecl->hasInit() && "Unsupported init for scf.for op"); + // indvar - loop iterator + VarDecl *indVar = dyn_cast(indDecl); + + // Step check - has step & step has to be a positive constant + bool stepHold = false; + auto inc = Stmt->getInc(); + UnaryOperator *unaryOp; + CompoundAssignOperator *compAssign; + int step; + if (inc) { + unaryOp = dyn_cast(inc); + compAssign = dyn_cast(inc); + // step has to be in the forms like i++ or i+=c + if (unaryOp) { + auto varDecl = dyn_cast(unaryOp->getSubExpr()); + if (varDecl) { + if (dyn_cast(varDecl->getDecl()) == indVar) { + if (unaryOp->isIncrementOp()) { + stepHold = true; + step = 1; + } + } + } + } else if (compAssign) { + auto varDecl = dyn_cast(compAssign->getLHS()); + if (varDecl) { + if (dyn_cast(varDecl->getDecl()) == indVar) { + // Floating point iterator not supported yet + auto isStepConstant = + dyn_cast(compAssign->getRHS()); + if (isStepConstant) { + if (isStepConstant->getValue().isStrictlyPositive()) { + stepHold = true; + step = isStepConstant->getValue().getZExtValue(); + } + } + } + } + } + } + if (!stepHold) + return false; + + // Condition check - has condition & inequality check + bool condHold = false; + bool equality = false; + auto isIterLHS = 0; + auto cond = Stmt->getCond(); + BinaryOperator *binaryCond; + clang::Stmt *hbExpr; + if (cond) { + binaryCond = dyn_cast(cond); + auto opcode = binaryCond->getOpcode(); + // loop iterator should be seen in the loop condition + isIterLHS = (opcode == clang::BinaryOperator::Opcode::BO_LT || + opcode == clang::BinaryOperator::Opcode::BO_LE) + ? 1 + : (opcode == clang::BinaryOperator::Opcode::BO_GT || + opcode == clang::BinaryOperator::Opcode::BO_GE) + ? 2 + : 0; + if (binaryCond && isIterLHS) { + equality = (opcode == clang::BinaryOperator::Opcode::BO_GT || + opcode == clang::BinaryOperator::Opcode::BO_LT) + ? false + : true; + clang::Stmt *iter; + if (isIterLHS > 1) { + iter = binaryCond->getRHS(); + hbExpr = binaryCond->getLHS(); + } else { + iter = binaryCond->getLHS(); + hbExpr = binaryCond->getRHS(); + } + if (iter->getStmtClass() == Stmt::StmtClass::ImplicitCastExprClass && + dyn_cast( + dyn_cast(iter)->getSubExpr()) + ->getDecl() == indVar) + condHold = true; + } + } + if (!condHold) + return false; + + // // Code generation + // // Temporary store the variables in the current region + // // TODO: This inefficient, to be optimized + // str_vec varTemp; + // for (auto &var : vars) + // varTemp.push_back(var->name); + + // SCF header generation - try to extract the four constraints + // indvar - loop iterator + // lb - low bound of the loop + // hb - high bound of the loop + // step - step of the loop + std::string buffer; + auto lb = exprGen(&*(indDecl->getInit()), -1, &buffer); + auto indID = (initDecl) ? addVar(indVar, getIndexType()) + : castToIndex(getVarID(indVar), exprBuff); + auto hb = exprGen(&*hbExpr, -1, &buffer); + auto indType = getVarType(indID); + auto stepVar = constIntGen(step, indType, &buffer); + + // Generate yield variable list, i.e. get the set of modified variables + int_vec yields; + auto loopBody = Stmt->getBody(); + getYieldVars(loopBody, &yields); + std::string argsBuff; + if (yields.size() > 0) { + for (auto &y : yields) { + auto name = getVarName(y); + argsBuff += incrVarName(y) + " = " + name + ", "; + } + argsBuff.pop_back(); + argsBuff.pop_back(); + } + + // Generate loop body + // TODO: loop iterator updated within the loop body? + std::string body; + stmtGen(loopBody, &body); + + // Generate yield op + auto toYield = (yields.size() > 0); + std::string header, yieldTypes; + str_vec yieldsVars; + if (toYield) { + for (auto &y : yields) { + yieldsVars.push_back(getVarName(y)); + yieldTypes += typePrint(y) + ", "; + } + yieldTypes.pop_back(); + yieldTypes.pop_back(); + body += "scf.yield "; + for (auto y : yields) + body += getVarName(y) + ", "; + body.pop_back(); + body.pop_back(); + body += " : " + yieldTypes + "\n"; + for (unsigned long i = 0; i < yields.size(); i++) + setVarName(yields[i], yieldsVars[i]); + for (auto y : yields) + header += incrVarName(y) + ", "; + header.pop_back(); + header.pop_back(); + header += " = "; + } + auto lbIdx = castToIndex(lb, &buffer); + auto hbIdx = castToIndex(hb, &buffer); + auto stepIdx = castToIndex(stepVar, &buffer); + *exprBuff += buffer; + header += "scf.for " + getVarName(indID) + " = " + getVarName(lbIdx) + + " to " + getVarName(hbIdx) + " step " + getVarName(stepIdx) + + "\n"; + header += "iter_args(" + argsBuff + ") -> (" + yieldTypes + ") {\n"; + *exprBuff += header + body + "}\n"; + return true; + } + + // Generate generic for loops without jumping as scf.while loops. + // init is outside of the while loop, condition is the same, and body followed + // by inc is the body of a while loop + void generalForGen(ForStmt *Stmt, std::string *exprBuff) { + auto init = Stmt->getInit(); + if (init) + stmtGen(init, exprBuff); + auto cond = Stmt->getCond(); + assert(cond); + auto inc = Stmt->getInc(); + + // Generate yield variable list, i.e. get the set of modified variables + int_vec yields; + auto loopBody = Stmt->getBody(); + getYieldVars(loopBody, &yields); + getYieldVars(inc, &yields); + std::string argsBuff, conditionArgs, yieldTypes; + if (yields.size() > 0) { + for (auto &y : yields) { + auto name = getVarName(y); + auto newName = incrVarName(y); + yieldTypes += typePrint(y) + ", "; + argsBuff += newName + " = " + name + ", "; + conditionArgs += newName + ", "; + } + argsBuff.pop_back(); + argsBuff.pop_back(); + conditionArgs.pop_back(); + conditionArgs.pop_back(); + yieldTypes.pop_back(); + yieldTypes.pop_back(); + } + auto header = "scf.while (" + argsBuff + ") : (" + yieldTypes + ") -> (" + + yieldTypes + ") {\n"; + auto condition = exprGen(cond, -1, &header); + std::string bodyArgs; + if (yields.size() > 0) { + for (auto &y : yields) + bodyArgs += incrVarName(y) + " : " + typePrint(y) + ", "; + bodyArgs.pop_back(); + bodyArgs.pop_back(); + } + header += "scf.condition(" + getVarName(condition) + ") " + conditionArgs + + " : " + yieldTypes + "\n} do {\n^bb0(" + bodyArgs + "): \n"; + + std::string body; + stmtGen(loopBody, &body); + stmtGen(inc, &body); + + // Generate yield op + auto toYield = (yields.size() > 0); + std::string results; + str_vec yieldsVars; + if (toYield) { + for (auto y : yields) + yieldsVars.push_back(getVarName(y)); + + body += "scf.yield "; + for (auto y : yields) + body += getVarName(y) + ", "; + body.pop_back(); + body.pop_back(); + body += " : " + yieldTypes + "\n"; + for (unsigned long i = 0; i < yields.size(); i++) + setVarName(yields[i], yieldsVars[i]); + for (auto y : yields) + results += incrVarName(y) + ", "; + results.pop_back(); + results.pop_back(); + results += " = "; + } + *exprBuff += results + header + body + "}\n"; + } + + // Generate for loops + void forGen(ForStmt *Stmt, std::string *exprBuff) { + if (!trySCFForGen(Stmt, exprBuff)) { + if (!hasJump(Stmt)) { + // Generate irregular for loops into while loops + generalForGen(Stmt, exprBuff); + } else + error("Jumps in loops unsupported for now."); + } + } + + // Generate variable declarations + void declStmtGen(DeclStmt *Stmt, std::string *exprBuff) { + for (auto decl : Stmt->decls()) { + if (VarDecl *temp = dyn_cast(decl)) { + auto var = addVar(&*temp); + if (temp->getType()->isArrayType()) { + assert(!temp->hasInit() && "Initialised arrays not supported"); + *exprBuff += + incrVarName(var) + " = alloc() : " + typePrint(var) + "\n"; + } else if (temp->hasInit()) + exprGen(&*(temp->getInit()), var, exprBuff); + else + // TODO: declaring unintialised varaibles are initialised to 0 in + // MLIR (will be processed by an additional MLIR pass) + *exprBuff += + incrVarName(var) + " = constant 0 : " + typePrint(var) + "\n"; + + } else + error("Undefined declaration found in stmtGen."); + } + } + + // Generate compound assignment + void compoundAssignStmtGen(CompoundAssignOperator *compAssign, + std::string *stmtBuff) { + auto rhs = exprGen(&*(compAssign->getRHS()), -1, stmtBuff); + auto rhsType = getVarType(rhs); + auto lhs = compAssign->getLHS(); + auto lhsAsArray = dyn_cast(lhs); + + if (lhsAsArray) { // Is an array + int_vec idxList; + getArrayIdxExpr(lhsAsArray, &idxList, stmtBuff); + auto array = getArrayLabel(lhsAsArray); + auto lhsVar = addLoad(array, -1, stmtBuff, &idxList); + auto lhsVarType = getVarType(lhsVar); + assert(lhsVarType == rhsType && + "casting for compound assignment not supported"); + auto result = addVar(nullptr, lhsVarType); + *stmtBuff += getVarName(result) + " = " + + binaryOpPrint(compAssign, rhsType) + " " + + getVarName(lhsVar) + ", " + getVarName(rhs) + " : " + + typePrint(rhsType) + "\n"; + *stmtBuff += addStore(result, array, &idxList); + } else { + auto lhsVar = dyn_cast(lhs); + assert(lhsVar && + "Unsupproted destination format for compound assignment"); + auto result = getVarID(lhsVar->getDecl()); + assert(getVarType(result) == rhsType && + "casting for compound assignment not supported"); + auto buffer = " = " + binaryOpPrint(compAssign, rhsType) + " " + + getVarName(result) + ", " + getVarName(rhs) + " : " + + typePrint(rhsType) + "\n"; + *stmtBuff += incrVarName(result) + buffer; + } + } + + // Generate general statements + void stmtGen(Stmt *Stmt, std::string *stmtBuff) { + if (debug) + std::cout << "Found stmt: " << Stmt->getStmtClassName() << "\n"; + std::string buffer = ""; + AsmStmt::StmtClass Type = Stmt->getStmtClass(); + int var; + auto children = Stmt->children(); + assert(!unaryUpdateList.size()); + switch (Type) { + case Stmt::StmtClass::BinaryOperatorClass: + binaryOperatorGen(Stmt, -1, stmtBuff); + break; + case Stmt::StmtClass::DeclStmtClass: + declStmtGen(dyn_cast(Stmt), stmtBuff); + break; + case Stmt::StmtClass::CompoundStmtClass: + for (auto &st : children) + stmtGen(st, stmtBuff); + break; + case Stmt::StmtClass::IfStmtClass: + ifGen(dyn_cast(Stmt), stmtBuff); + break; + case Stmt::StmtClass::ReturnStmtClass: + assert(std::distance(std::cbegin(children), std::cend(children)) == 1); + var = exprGen(*(Stmt->child_begin()), -1, stmtBuff); + *stmtBuff += std::string("return " + getVarName(var) + " : " + + typePrint(var) + "\n"); + break; + case Stmt::StmtClass::UnaryOperatorClass: + unaryExprGen(Stmt, -1, stmtBuff); + break; + case Stmt::StmtClass::ForStmtClass: + forGen(dyn_cast(Stmt), stmtBuff); + break; + case Stmt::StmtClass::CompoundAssignOperatorClass: + compoundAssignStmtGen(dyn_cast(Stmt), stmtBuff); + break; + default: + error("Undefined statement class in stmtGen: " + + std::string(Stmt->getStmtClassName())); + } + if (unaryUpdateList.size() > 0) { + for (auto &i : unaryUpdateList) + incrVarName(i); + unaryUpdateList.clear(); + } + } + + // Generate functions + void funcGen(FunctionDecl *func) { + std::string funcBody; + mlirOut << "func @" << func->getNameInfo().getName().getAsString() << "("; + // Print function arguments + if (!func->param_empty()) { + std::string buffer; + for (auto ¶meter : func->parameters()) { + auto type = parameter->getOriginalType(); + auto argName = getVarName(addVar(parameter, type)); + if (type->isConstantArrayType()) + buffer += argName + " : " + typePrint(type) + ", "; + else if (!type->isPointerType() && !type->isArrayType()) { + buffer += getVarName(addVar(parameter, type)) + " : " + + typePrint(type) + ", "; + } else + error("Undefined argument type for function " + + func->getNameInfo().getName().getAsString()); + } + mlirOut << buffer.substr(0, buffer.length() - 2) << ")"; + } else + mlirOut << ")"; + // Print function return type + bool appendReturn; + if (!func->isNoReturn()) { + auto returnType = typePrint((func->getDeclaredReturnType())); + appendReturn = (returnType == "") ? true : false; + mlirOut << " -> (" << returnType << ")"; + } else + appendReturn = true; + mlirOut << "{\n"; + stmtGen(func->getBody(), &funcBody); + mlirOut << funcBody; + // Add a return op if the return type is void + if (appendReturn) + mlirOut << "return\n"; + mlirOut << "}\n"; + } +}; + +// Consumer +class MLIRGeneratorConsumer : public clang::ASTConsumer { +public: + explicit MLIRGeneratorConsumer(ASTContext *Context) : Visitor(Context) {} + + virtual void HandleTranslationUnit(ASTContext &Context) override { + Visitor.TraverseDecl(Context.getTranslationUnitDecl()); + } + +private: + MLIRGeneratorVisitor Visitor; +}; + +// Action +class MLIRGeneratorAction : public clang::ASTFrontendAction { +public: + virtual std::unique_ptr + CreateASTConsumer(clang::CompilerInstance &Compiler, + llvm::StringRef InFile) override { + return std::unique_ptr( + new MLIRGeneratorConsumer(&Compiler.getASTContext())); + } +}; + +// Print help infomation before exit +void exit(void) { + std::cout << "Usage: scalehls-clang [options] " << std::endl; + std::cout << "--debug\t\t\t\t- Enable debug mode to print debug information" + << std::endl; + std::cout << "-o \t\t\t- Specify the file name of the output MLIR" + << std::endl; +} + +// Tool configurations +// * --help display help infomation +// * --debug print additional debugging info +bool option(std::string op) { + if (op == "--help") + return false; + else if (op == "--debug") { + debug = true; + return true; + } else { + std::cout << "\033[0;31mUndefined option " << op << "\033[0m" << std::endl; + return false; + } +} + +int main(int argc, char **argv) { + if (argc < 2) { + exit(); + return -1; + } + + // Detect configurations and input file + std::string fileName; + int i; + for (i = 1; i < argc; i++) { + if (argv[i][0] == '-' && argv[i][1] == '-') { + if (!option(argv[i])) { + exit(); + return -1; + } + } else if (std::string(argv[i]) == "-o" && i < argc - 1) { + ++i; + output = argv[i]; + } else if (fileName.empty()) + fileName = argv[i]; + else { + std::cout << "\033[0;31mError: Unrecognised command.\033[0m\n"; + exit(); + return -1; + } + } + std::ifstream fin(fileName); + if (!fin.is_open()) { + std::cout << "\033[0;31mError: Cannot find file " << fileName << "\033[0m" + << std::endl; + return -1; + } + + // Process + std::stringstream code; + code << fin.rdbuf(); + if (debug) { + std::cout << "C input: " << std::endl; + std::cout << code.str() << std::endl; + } + mlirOut.open(output); + if (debug) + std::cout << "Debug log: " << std::endl; + clang::tooling::runToolOnCode(std::make_unique(), + code.str()); + mlirOut.close(); + std::cout << "Translation finished." << std::endl; +} From b46e8715329f674f28bf6c4935c0cabe2863f30d Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Fri, 30 Apr 2021 23:25:37 +0100 Subject: [PATCH 6/8] Move clang config to top --- tools/scalehls-clang/CMakeLists.txt | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tools/scalehls-clang/CMakeLists.txt b/tools/scalehls-clang/CMakeLists.txt index 634bbb6..6365003 100644 --- a/tools/scalehls-clang/CMakeLists.txt +++ b/tools/scalehls-clang/CMakeLists.txt @@ -1,11 +1,6 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -#include_directories(../../llvm/clang/include/) -#include_directories(../../llvm/build/tools/clang/include) - -#find_package(Clang REQUIRED) - add_llvm_tool(scalehls-clang scalehls-clang.cpp ) From 9ac014900b7b2ef8bcfbf6b97886a84b7b625b46 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Fri, 30 Apr 2021 23:54:57 +0100 Subject: [PATCH 7/8] format --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index bc6353b..bfd061b 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ $ git submodule update ``` ### 1. Install LLVM and MLIR -To build LLVM and MLIR, run: +This step assumes this repository is cloned to `$SCALEHLS_DIR`. To build LLVM and MLIR, run: ```sh $ mkdir $SCALEHLS_DIR/llvm/build $ cd $SCALEHLS_DIR/llvm/build @@ -28,7 +28,7 @@ $ ninja check-mlir ``` ### 2. Install ScaleHLS -This step assumes this repository is cloned to `$SCALEHLS_DIR`. To build and launch the tests, run: +To build and launch the tests, run: ```sh $ mkdir $SCALEHLS_DIR/build $ cd $SCALEHLS_DIR/build From f907c9810ace43c838d65a15efc262e54d85ac50 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Sat, 1 May 2021 00:37:27 -0500 Subject: [PATCH 8/8] [scalehls-clang] add syrk test case; update using llvm cl tools --- README.md | 1 - test/lit.cfg.py | 5 +- test/scalehls-clang/syrk.c | 13 +++ tools/scalehls-clang/CMakeLists.txt | 2 + tools/scalehls-clang/scalehls-clang.cpp | 136 ++++++++++-------------- 5 files changed, 74 insertions(+), 83 deletions(-) create mode 100644 test/scalehls-clang/syrk.c diff --git a/README.md b/README.md index bfd061b..8b66b9e 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,6 @@ This project aims to create a framework that ultimately converts an algorithm wr ## Quick Start ### 0. Download ScaleHLS and LLVM - ``` $ git clone git@github.com:hanchenye/scalehls.git $ cd scalehls diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 2bd2fd8..f8e6505 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -21,7 +21,7 @@ config.name = 'SCALEHLS' config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.mlir', '.ini'] +config.suffixes = ['.mlir', '.ini', '.c'] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) @@ -51,7 +51,8 @@ config.test_exec_root = os.path.join(config.scalehls_obj_root, 'test') # Tweak the PATH to include the tools dir. llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) -tool_dirs = [config.scalehls_tools_dir, config.mlir_tools_dir, config.llvm_tools_dir] +tool_dirs = [config.scalehls_tools_dir, + config.mlir_tools_dir, config.llvm_tools_dir] tools = [ 'scalehls-clang', 'scalehls-opt', diff --git a/test/scalehls-clang/syrk.c b/test/scalehls-clang/syrk.c new file mode 100644 index 0000000..b7994de --- /dev/null +++ b/test/scalehls-clang/syrk.c @@ -0,0 +1,13 @@ +// RUN: scalehls-clang %s | FileCheck %s + +// CHECK: func @syrk( +void syrk(float alpha, float beta, float C[32][32], float A[32][32]) { + for (int i = 0; i < 32; i++) { + for (int j = 0; j <= i; j++) { + C[i][j] *= beta; + for (int k = 0; k < 32; k++) { + C[i][j] += alpha * A[i][k] * A[j][k]; + } + } + } +} diff --git a/tools/scalehls-clang/CMakeLists.txt b/tools/scalehls-clang/CMakeLists.txt index 6365003..9914dc5 100644 --- a/tools/scalehls-clang/CMakeLists.txt +++ b/tools/scalehls-clang/CMakeLists.txt @@ -9,6 +9,8 @@ llvm_update_compile_flags(scalehls-clang) target_link_libraries(scalehls-clang PRIVATE + ${dialect_libs} + clangFrontend clangTooling clangBasic diff --git a/tools/scalehls-clang/scalehls-clang.cpp b/tools/scalehls-clang/scalehls-clang.cpp index 9b7d85a..d2dc631 100644 --- a/tools/scalehls-clang/scalehls-clang.cpp +++ b/tools/scalehls-clang/scalehls-clang.cpp @@ -4,6 +4,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Support/FileUtilities.h" #include "clang/AST/ASTConsumer.h" #include "clang/AST/Decl.h" #include "clang/AST/Expr.h" @@ -16,6 +17,9 @@ #include "clang/Tooling/Tooling.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ToolOutputFile.h" #include #include @@ -34,15 +38,14 @@ typedef std::vector str_vec; // This tool emits the MLIR code from C code by traversing the Clang AST. // ------------------------------------------- -// Default mode configuration -// Disable full flow check - enabling debug mode will lower the MLIR through -// SCF, standard, handshake till HPX Ops. -bool debug = false; -// Output file name -std::string output = "output.mlir"; +static llvm::cl::opt + inputFilename(llvm::cl::Positional, llvm::cl::desc("Input filename"), + llvm::cl::init("-")); + +static llvm::cl::opt + outputFilename("o", llvm::cl::desc("Output filename"), + llvm::cl::value_desc("filename"), llvm::cl::init("-")); -// Output file stream -std::ofstream mlirOut; // Number of variables in the MLIR code including intermediate variables int varCnt = 0; // Number of intermediate variables @@ -52,19 +55,20 @@ int_vec unaryUpdateList; // Print error messages auto error = [](auto str) { - std::cout << "\033[0;31mError: " << str << "\033[0m" << std::endl; + llvm::errs() << "\033[0;31mError: " << str << "\033[0m\n"; assert(0); }; // Print highlight messages, like warning auto message = [](auto str) { - std::cout << "\033[0;33mWarning: " << str << "\033[0m" << std::endl; + llvm::errs() << "\033[0;33mWarning: " << str << "\033[0m\n"; }; // Translation unit class MLIRGeneratorVisitor : public RecursiveASTVisitor { public: - explicit MLIRGeneratorVisitor(ASTContext *Context) : Context(Context) {} + explicit MLIRGeneratorVisitor(ASTContext *Context, raw_ostream &mlirOut) + : Context(Context), mlirOut(mlirOut) {} virtual bool VisitFunctionDecl(FunctionDecl *func) { // Currently ignore the function prototype @@ -85,6 +89,8 @@ private: typedef std::vector var_vec; ASTContext *Context; + raw_ostream &mlirOut; + // Variable list var_vec vars; @@ -466,7 +472,8 @@ private: args.pop_back(); args.pop_back(); args += "] : " + typePrint(arrayType) + "\n"; - buff += "store " + getVarName(src) + ", " + getVarName(dst) + "[" + args; + buff += + "memref.store " + getVarName(src) + ", " + getVarName(dst) + "[" + args; return buff; } @@ -486,7 +493,7 @@ private: res = resultVar; resName = incrVarName(res); } - std::string args = resName + " = load " + getVarName(array) + "["; + std::string args = resName + " = memref.load " + getVarName(array) + "["; for (auto &i : *index) args += getVarName(castToIndex(i, buff)) + ", "; args.pop_back(); @@ -706,8 +713,8 @@ private: // Generate a generic expression int exprGen(clang::Stmt *Stmt, int resultVar, std::string *exprBuff) { - if (debug) - std::cout << "Found stmt: " << Stmt->getStmtClassName() << "\n"; + // llvm::dbgs() << "Found stmt: " << Stmt->getStmtClassName() << "\n"; + auto Type = Stmt->getStmtClass(); auto result = -1; switch (Type) { @@ -990,10 +997,10 @@ private: isIterLHS = (opcode == clang::BinaryOperator::Opcode::BO_LT || opcode == clang::BinaryOperator::Opcode::BO_LE) ? 1 - : (opcode == clang::BinaryOperator::Opcode::BO_GT || - opcode == clang::BinaryOperator::Opcode::BO_GE) - ? 2 - : 0; + : (opcode == clang::BinaryOperator::Opcode::BO_GT || + opcode == clang::BinaryOperator::Opcode::BO_GE) + ? 2 + : 0; if (binaryCond && isIterLHS) { equality = (opcode == clang::BinaryOperator::Opcode::BO_GT || opcode == clang::BinaryOperator::Opcode::BO_LT) @@ -1238,8 +1245,8 @@ private: // Generate general statements void stmtGen(Stmt *Stmt, std::string *stmtBuff) { - if (debug) - std::cout << "Found stmt: " << Stmt->getStmtClassName() << "\n"; + // llvm::dbgs() << "Found stmt: " << Stmt->getStmtClassName() << "\n"; + std::string buffer = ""; AsmStmt::StmtClass Type = Stmt->getStmtClass(); int var; @@ -1328,7 +1335,8 @@ private: // Consumer class MLIRGeneratorConsumer : public clang::ASTConsumer { public: - explicit MLIRGeneratorConsumer(ASTContext *Context) : Visitor(Context) {} + explicit MLIRGeneratorConsumer(ASTContext *Context, raw_ostream &mlirOut) + : Visitor(Context, mlirOut) {} virtual void HandleTranslationUnit(ASTContext &Context) override { Visitor.TraverseDecl(Context.getTranslationUnitDecl()); @@ -1341,12 +1349,17 @@ private: // Action class MLIRGeneratorAction : public clang::ASTFrontendAction { public: + explicit MLIRGeneratorAction(raw_ostream &mlirOut) : mlirOut(mlirOut) {} + virtual std::unique_ptr CreateASTConsumer(clang::CompilerInstance &Compiler, llvm::StringRef InFile) override { return std::unique_ptr( - new MLIRGeneratorConsumer(&Compiler.getASTContext())); + new MLIRGeneratorConsumer(&Compiler.getASTContext(), mlirOut)); } + +private: + raw_ostream &mlirOut; }; // Print help infomation before exit @@ -1358,66 +1371,29 @@ void exit(void) { << std::endl; } -// Tool configurations -// * --help display help infomation -// * --debug print additional debugging info -bool option(std::string op) { - if (op == "--help") - return false; - else if (op == "--debug") { - debug = true; - return true; - } else { - std::cout << "\033[0;31mUndefined option " << op << "\033[0m" << std::endl; - return false; - } -} - int main(int argc, char **argv) { - if (argc < 2) { - exit(); - return -1; + if (!llvm::cl::ParseCommandLineOptions(argc, argv, + "HLS C front-end for MLIR\n")) + exit(1); + + // Set up the input and output file. + std::string errorMessage; + auto file = mlir::openInputFile(inputFilename, &errorMessage); + if (!file) { + llvm::errs() << errorMessage << "\n"; + exit(1); } - // Detect configurations and input file - std::string fileName; - int i; - for (i = 1; i < argc; i++) { - if (argv[i][0] == '-' && argv[i][1] == '-') { - if (!option(argv[i])) { - exit(); - return -1; - } - } else if (std::string(argv[i]) == "-o" && i < argc - 1) { - ++i; - output = argv[i]; - } else if (fileName.empty()) - fileName = argv[i]; - else { - std::cout << "\033[0;31mError: Unrecognised command.\033[0m\n"; - exit(); - return -1; - } - } - std::ifstream fin(fileName); - if (!fin.is_open()) { - std::cout << "\033[0;31mError: Cannot find file " << fileName << "\033[0m" - << std::endl; - return -1; + auto output = mlir::openOutputFile(outputFilename, &errorMessage); + if (!output) { + llvm::errs() << errorMessage << "\n"; + exit(1); } - // Process - std::stringstream code; - code << fin.rdbuf(); - if (debug) { - std::cout << "C input: " << std::endl; - std::cout << code.str() << std::endl; - } - mlirOut.open(output); - if (debug) - std::cout << "Debug log: " << std::endl; - clang::tooling::runToolOnCode(std::make_unique(), - code.str()); - mlirOut.close(); - std::cout << "Translation finished." << std::endl; + // Process the parsing. + clang::tooling::runToolOnCode( + std::make_unique(output->os()), file->getBuffer()); + + output->keep(); + return 0; }