diff --git a/include/polygeist/Passes/CMakeLists.txt b/include/polygeist/Passes/CMakeLists.txt index acd0afc..da545fc 100644 --- a/include/polygeist/Passes/CMakeLists.txt +++ b/include/polygeist/Passes/CMakeLists.txt @@ -2,4 +2,4 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name polygeist) add_public_tablegen_target(MLIRPolygeistPassIncGen) -add_mlir_doc(Passes PolygeistPasses ./ -gen-pass-doc) \ No newline at end of file +add_mlir_doc(Passes PolygeistPasses ./ -gen-pass-doc) diff --git a/lib/polygeist/CMakeLists.txt b/lib/polygeist/CMakeLists.txt index 4023d1c..4c850e0 100644 --- a/lib/polygeist/CMakeLists.txt +++ b/lib/polygeist/CMakeLists.txt @@ -11,4 +11,4 @@ MLIRPolygeistOpsIncGen LINK_LIBS PUBLIC MLIRIR ) -add_subdirectory(Passes) \ No newline at end of file +add_subdirectory(Passes) diff --git a/mlir-clang/Lib/pragmaHandler.cc b/mlir-clang/Lib/pragmaHandler.cc index 95da5b9..74af47e 100644 --- a/mlir-clang/Lib/pragmaHandler.cc +++ b/mlir-clang/Lib/pragmaHandler.cc @@ -6,6 +6,7 @@ #include "clang/Lex/LexDiagnostic.h" #include "clang/Lex/LiteralSupport.h" #include "clang/Lex/Preprocessor.h" +#include "clang/Parse/ParseDiagnostic.h" #include "clang/Sema/Sema.h" #include "llvm/ADT/DenseMap.h" @@ -31,34 +32,70 @@ public: /// the function decl that lower_to is attached to with that MLIR function /// symbol in the class-referenced dictionary. /// - /// TODO: Handle assertions properly. void HandlePragma(Preprocessor &PP, PragmaIntroducer Introducer, Token &PragmaTok) { Token Tok; PP.Lex(Tok); // lparen - assert(Tok.is(tok::l_paren) && "lower_to should start with '('."); + if (Tok.isNot(tok::l_paren)) { + PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_lparen) + << "lower_to"; + return; + } - Token FuncIdTok; // function identifier - PP.Lex(FuncIdTok); - assert(FuncIdTok.is(tok::identifier) && - "The first argument of lower_to should be an identifier."); + Token PrevTok = Tok; + llvm::StringRef FuncId = llvm::StringRef(); + llvm::StringRef SymbolName = llvm::StringRef(); + while (Tok.isNot(tok::eod)) { + Token CurrentTok; + PP.Lex(CurrentTok); - llvm::StringRef FuncId = FuncIdTok.getIdentifierInfo()->getName(); + // rparen. + if (PrevTok.is(tok::string_literal)) { + if (CurrentTok.isNot(tok::r_paren)) { + PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_rparen) + << "lower_to"; + return; + } else { + break; + } + } - PP.Lex(Tok); // comma - assert(Tok.is(tok::comma) && "The first and second argument of lower_to " - "should be separated by a comma."); + // function identifier. + if (PrevTok.is(tok::l_paren)) { + if (CurrentTok.isNot(tok::identifier)) { + PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_identifier) + << "lower_to"; + return; + } else { + FuncId = CurrentTok.getIdentifierInfo()->getName(); + } + } - // Parse the string literal argument, which is the MLIR function symbol. - SmallVector SymbolToks; - Token SymbolTok; - PP.Lex(SymbolTok); - assert(SymbolTok.is(tok::string_literal) && - "The second argument of lower_to should be a string literal."); - SymbolToks.push_back(SymbolTok); - StringRef SymbolName = StringLiteralParser(SymbolToks, PP).GetString(); - PP.Lex(Tok); // rparen - assert(Tok.is(tok::r_paren) && "lower_to should end with '('."); + // comma. + if (PrevTok.is(tok::identifier)) { + if (CurrentTok.isNot(tok::comma)) { + PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_punc) + << "lower_to"; + return; + } + } + + // string literal, which is the MLIR function symbol. + if (PrevTok.is(tok::comma)) { + if (CurrentTok.isNot(tok::string_literal)) { + PP.Diag(CurrentTok.getLocation(), + diag::warn_pragma_expected_section_name) + << "lower to"; + return; + } else { + SmallVector SymbolToks; + SymbolToks.push_back(CurrentTok); + SymbolName = StringLiteralParser(SymbolToks, PP).GetString(); + } + } + + PrevTok = CurrentTok; + } // Link SymbolName with the function. auto result = Info.SymbolTable.try_emplace(FuncId, SymbolName); diff --git a/mlir-clang/Lib/utils.cc b/mlir-clang/Lib/utils.cc index 28aa0fc..2c1f647 100644 --- a/mlir-clang/Lib/utils.cc +++ b/mlir-clang/Lib/utils.cc @@ -9,9 +9,25 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" + using namespace mlir; using namespace llvm; +Operation *mlirclang::buildLinalgOp(const AbstractOperation *op, + ArrayRef operands, + ArrayRef opResultTypes, + OpBuilder &b) { + StringRef name = op->name; + if (name.compare("linalg.copy") == 0) { + return b.create(b.getUnknownLoc(), operands[1], + operands[0]); + } else { + llvm::report_fatal_error(llvm::Twine("builder not supported for: ") + name); + return nullptr; + } +} + Operation *mlirclang::replaceFuncByOperation(FuncOp f, StringRef opName, ArrayRef operands, OpBuilder &b) { @@ -21,6 +37,9 @@ Operation *mlirclang::replaceFuncByOperation(FuncOp f, StringRef opName, const AbstractOperation *op = AbstractOperation::lookup(opName, ctx); + if (opName.startswith("linalg")) + return buildLinalgOp(op, operands, f.getCallableResults(), b); + // NOTE: The attributes of the provided FuncOp is ignored. OperationState opState(b.getUnknownLoc(), op->name, ValueRange(operands), f.getCallableResults(), {}); diff --git a/mlir-clang/Lib/utils.h b/mlir-clang/Lib/utils.h index a0d3904..9f2334f 100644 --- a/mlir-clang/Lib/utils.h +++ b/mlir-clang/Lib/utils.h @@ -10,6 +10,8 @@ class Operation; class FuncOp; class Value; class OpBuilder; +class AbstractOperation; +class Type; } // namespace mlir namespace llvm { @@ -26,6 +28,10 @@ namespace mlirclang { mlir::Operation *replaceFuncByOperation(mlir::FuncOp f, llvm::StringRef opName, llvm::ArrayRef operands, mlir::OpBuilder &b); +mlir::Operation *buildLinalgOp(const mlir::AbstractOperation *op, + llvm::ArrayRef operands, + llvm::ArrayRef results, + mlir::OpBuilder &b); } // namespace mlirclang diff --git a/mlir-clang/Test/Verification/lower-to-linalg-op.c b/mlir-clang/Test/Verification/lower-to-linalg-op.c new file mode 100644 index 0000000..dd699c7 --- /dev/null +++ b/mlir-clang/Test/Verification/lower-to-linalg-op.c @@ -0,0 +1,16 @@ +// RUN: mlir-clang %s | FileCheck %s + +#pragma lower_to(copy_op, "linalg.copy") +void copy_op(int a[3][3], int b[3][3]) { + for (int i = 0; i < 3; i++) + for (int j = 0; j < 3; j++) + b[i][j] = a[i][j]; +} + +int main() { + int a[3][3]; + int b[3][3]; + // CHECK: linalg.copy({{.*}}, {{.*}}) : memref<3x3xi32>, memref<3x3xi32> + copy_op(a, b); + return 0; +} diff --git a/mlir-clang/mlir-clang.cc b/mlir-clang/mlir-clang.cc index 9a54433..a3c1e5b 100644 --- a/mlir-clang/mlir-clang.cc +++ b/mlir-clang/mlir-clang.cc @@ -4,6 +4,7 @@ #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SCF/SCF.h" @@ -129,6 +130,7 @@ int main(int argc, char **argv) { context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); + context.getOrLoadDialect(); context.getOrLoadDialect(); // MLIRContext context;