From 3e7b12a2c419a2337f2606a02bd99aae080f8929 Mon Sep 17 00:00:00 2001 From: Vincent Zhao Date: Fri, 14 May 2021 21:28:26 +0100 Subject: [PATCH] lower_to pragma handler --- llvm-project | 2 +- mlir-clang/Lib/PragmaLowerToHandler.cc | 81 ++++++++++++++++++++++++++ mlir-clang/Lib/PragmaLowerToHandler.h | 17 ++++++ mlir-clang/Lib/clang-mlir.cc | 8 +++ 4 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 mlir-clang/Lib/PragmaLowerToHandler.cc create mode 100644 mlir-clang/Lib/PragmaLowerToHandler.h diff --git a/llvm-project b/llvm-project index 85bac9d..d3e14fa 160000 --- a/llvm-project +++ b/llvm-project @@ -1 +1 @@ -Subproject commit 85bac9d7f93419ad73cee42a87465d50286a46fe +Subproject commit d3e14fafc69a07e3dab9ddb91f1d810bb5f8d7a0 diff --git a/mlir-clang/Lib/PragmaLowerToHandler.cc b/mlir-clang/Lib/PragmaLowerToHandler.cc new file mode 100644 index 0000000..dcec0f7 --- /dev/null +++ b/mlir-clang/Lib/PragmaLowerToHandler.cc @@ -0,0 +1,81 @@ +#include "PragmaLowerToHandler.h" + +#include "clang/AST/AST.h" +#include "clang/AST/ASTConsumer.h" +#include "clang/AST/Attr.h" +#include "clang/Frontend/FrontendPluginRegistry.h" +#include "clang/Lex/LexDiagnostic.h" +#include "clang/Lex/Preprocessor.h" +#include "clang/Sema/Sema.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +using namespace clang; +using namespace llvm; + +namespace { + +/// Handles the #pragma lower_to(, "") +/// directive. +class PragmaLowerToHandler : public PragmaHandler { + Sema &Actions; + + LowerToInfo &Info; + +public: + PragmaLowerToHandler(Sema &Actions, LowerToInfo &Info) + : PragmaHandler("lower_to"), Actions(Actions), Info(Info) {} + + /// The pragma handler will extract the single argument to the lower_to(...) + /// pragma definition, which is the target MLIR function symbol, and relate + /// the function decl that lower_to is attached to with that MLIR function + /// symbol in the class-referenced dictionary. + /// + /// TODO: Handle assertions properly. + void HandlePragma(Preprocessor &PP, PragmaIntroducer Introducer, + Token &PragmaTok) { + Token Tok; + PP.Lex(Tok); // lparen + assert(Tok.is(tok::l_paren) && "lower_to should start with '('."); + + Token FuncIdTok; // function identifier + PP.Lex(FuncIdTok); + assert(FuncIdTok.is(tok::identifier) && + "The first argument of lower_to should be an identifier."); + + llvm::StringRef FuncId = FuncIdTok.getIdentifierInfo()->getName(); + + PP.Lex(Tok); // comma + assert(Tok.is(tok::comma) && "The first and second argument of lower_to " + "should be separated by a comma."); + + // Parse the string literal argument, which is the MLIR function symbol. + SmallVector SymbolToks; + Token SymbolTok; + PP.Lex(SymbolTok); + assert(SymbolTok.is(tok::string_literal) && + "The second argument of lower_to should be a string literal."); + SymbolToks.push_back(SymbolTok); + clang::StringLiteral *SymbolName = cast( + Actions.ActOnStringLiteral(SymbolToks).get()); + + PP.Lex(Tok); // rparen + assert(Tok.is(tok::r_paren) && "lower_to should end with '('."); + + // Link SymbolName with the function. + auto result = Info.SymbolTable.try_emplace(FuncId, SymbolName->getString()); + assert(result.second && + "Shouldn't define lower_to over the same func id more than once."); + } + +private: +}; + +} // namespace + +void addPragmaLowerToHandlers(Preprocessor &PP, Sema &Actions, + LowerToInfo <Info) { + PP.AddPragmaHandler(new PragmaLowerToHandler(Actions, LTInfo)); +} diff --git a/mlir-clang/Lib/PragmaLowerToHandler.h b/mlir-clang/Lib/PragmaLowerToHandler.h new file mode 100644 index 0000000..32ef7db --- /dev/null +++ b/mlir-clang/Lib/PragmaLowerToHandler.h @@ -0,0 +1,17 @@ +#ifndef MLIR_TOOLS_MLIRCLANG_LIB_PRAGMALOWERTOHANDLER_H +#define MLIR_TOOLS_MLIRCLANG_LIB_PRAGMALOWERTOHANDLER_H + +#include "clang/Lex/Preprocessor.h" +#include "clang/Sema/Sema.h" + +#include "llvm/ADT/DenseMap.h" + +/// POD holds information processed from the lower_to pragma. +struct LowerToInfo { + llvm::StringMap SymbolTable; +}; + +void addPragmaLowerToHandlers(clang::Preprocessor &PP, clang::Sema &Actions, + LowerToInfo <Info); + +#endif diff --git a/mlir-clang/Lib/clang-mlir.cc b/mlir-clang/Lib/clang-mlir.cc index 90b1e44..3a0fc10 100644 --- a/mlir-clang/Lib/clang-mlir.cc +++ b/mlir-clang/Lib/clang-mlir.cc @@ -2373,6 +2373,14 @@ ValueWithOffsets MLIRScanner::VisitCallExpr(clang::CallExpr *expr) { i++; } + if (LTInfo.SymbolTable.count(tocall.getName())) { + return ValueWithOffsets( + replaceFuncByOperation(tocall, LTInfo.SymbolTable[tocall.getName()], + args, builder) + ->getResult(0), + /*isReference=*/false); + } + bool isArrayReturn = false; if (!(expr->isLValue() || expr->isXValue())) Glob.getMLIRType(expr->getType(), &isArrayReturn);