Use LowerTo pragma to generate a linalg.copyOp

- Refactor pragma lower to handling
This commit is contained in:
lchelini 2021-07-13 10:11:50 +02:00 committed by lorenzo chelini
parent 3b54e6eea4
commit e9b551c1e7
7 changed files with 102 additions and 22 deletions

View File

@ -6,6 +6,7 @@
#include "clang/Lex/LexDiagnostic.h"
#include "clang/Lex/LiteralSupport.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Parse/ParseDiagnostic.h"
#include "clang/Sema/Sema.h"
#include "llvm/ADT/DenseMap.h"
@ -31,34 +32,70 @@ public:
/// the function decl that lower_to is attached to with that MLIR function
/// symbol in the class-referenced dictionary.
///
/// TODO: Handle assertions properly.
void HandlePragma(Preprocessor &PP, PragmaIntroducer Introducer,
Token &PragmaTok) {
Token Tok;
PP.Lex(Tok); // lparen
assert(Tok.is(tok::l_paren) && "lower_to should start with '('.");
if (Tok.isNot(tok::l_paren)) {
PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_lparen)
<< "lower_to";
return;
}
Token FuncIdTok; // function identifier
PP.Lex(FuncIdTok);
assert(FuncIdTok.is(tok::identifier) &&
"The first argument of lower_to should be an identifier.");
Token PrevTok = Tok;
llvm::StringRef FuncId = llvm::StringRef();
llvm::StringRef SymbolName = llvm::StringRef();
while (Tok.isNot(tok::eod)) {
Token CurrentTok;
PP.Lex(CurrentTok);
llvm::StringRef FuncId = FuncIdTok.getIdentifierInfo()->getName();
// rparen.
if (PrevTok.is(tok::string_literal)) {
if (CurrentTok.isNot(tok::r_paren)) {
PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_rparen)
<< "lower_to";
return;
} else {
break;
}
}
PP.Lex(Tok); // comma
assert(Tok.is(tok::comma) && "The first and second argument of lower_to "
"should be separated by a comma.");
// function identifier.
if (PrevTok.is(tok::l_paren)) {
if (CurrentTok.isNot(tok::identifier)) {
PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_identifier)
<< "lower_to";
return;
} else {
FuncId = CurrentTok.getIdentifierInfo()->getName();
}
}
// Parse the string literal argument, which is the MLIR function symbol.
// comma.
if (PrevTok.is(tok::identifier)) {
if (CurrentTok.isNot(tok::comma)) {
PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_punc)
<< "lower_to";
return;
}
}
// string literal, which is the MLIR function symbol.
if (PrevTok.is(tok::comma)) {
if (CurrentTok.isNot(tok::string_literal)) {
PP.Diag(CurrentTok.getLocation(),
diag::warn_pragma_expected_section_name)
<< "lower to";
return;
} else {
SmallVector<Token, 1> SymbolToks;
Token SymbolTok;
PP.Lex(SymbolTok);
assert(SymbolTok.is(tok::string_literal) &&
"The second argument of lower_to should be a string literal.");
SymbolToks.push_back(SymbolTok);
StringRef SymbolName = StringLiteralParser(SymbolToks, PP).GetString();
PP.Lex(Tok); // rparen
assert(Tok.is(tok::r_paren) && "lower_to should end with '('.");
SymbolToks.push_back(CurrentTok);
SymbolName = StringLiteralParser(SymbolToks, PP).GetString();
}
}
PrevTok = CurrentTok;
}
// Link SymbolName with the function.
auto result = Info.SymbolTable.try_emplace(FuncId, SymbolName);

View File

@ -9,9 +9,25 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
using namespace mlir;
using namespace llvm;
Operation *mlirclang::buildLinalgOp(const AbstractOperation *op,
ArrayRef<Value> operands,
ArrayRef<Type> opResultTypes,
OpBuilder &b) {
StringRef name = op->name;
if (name.compare("linalg.copy") == 0) {
return b.create<linalg::CopyOp>(b.getUnknownLoc(), operands[1],
operands[0]);
} else {
llvm::report_fatal_error(llvm::Twine("builder not supported for: ") + name);
return nullptr;
}
}
Operation *mlirclang::replaceFuncByOperation(FuncOp f, StringRef opName,
ArrayRef<Value> operands,
OpBuilder &b) {
@ -21,6 +37,9 @@ Operation *mlirclang::replaceFuncByOperation(FuncOp f, StringRef opName,
const AbstractOperation *op = AbstractOperation::lookup(opName, ctx);
if (opName.startswith("linalg"))
return buildLinalgOp(op, operands, f.getCallableResults(), b);
// NOTE: The attributes of the provided FuncOp is ignored.
OperationState opState(b.getUnknownLoc(), op->name, ValueRange(operands),
f.getCallableResults(), {});

View File

@ -10,6 +10,8 @@ class Operation;
class FuncOp;
class Value;
class OpBuilder;
class AbstractOperation;
class Type;
} // namespace mlir
namespace llvm {
@ -26,6 +28,10 @@ namespace mlirclang {
mlir::Operation *replaceFuncByOperation(mlir::FuncOp f, llvm::StringRef opName,
llvm::ArrayRef<mlir::Value> operands,
mlir::OpBuilder &b);
mlir::Operation *buildLinalgOp(const mlir::AbstractOperation *op,
llvm::ArrayRef<mlir::Value> operands,
llvm::ArrayRef<mlir::Type> results,
mlir::OpBuilder &b);
} // namespace mlirclang

View File

@ -0,0 +1,16 @@
// RUN: mlir-clang %s | FileCheck %s
#pragma lower_to(copy_op, "linalg.copy")
void copy_op(int a[3][3], int b[3][3]) {
for (int i = 0; i < 3; i++)
for (int j = 0; j < 3; j++)
b[i][j] = a[i][j];
}
int main() {
int a[3][3];
int b[3][3];
// CHECK: linalg.copy({{.*}}, {{.*}}) : memref<3x3xi32>, memref<3x3xi32>
copy_op(a, b);
return 0;
}

View File

@ -4,6 +4,7 @@
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
@ -129,6 +130,7 @@ int main(int argc, char **argv) {
context.getOrLoadDialect<mlir::omp::OpenMPDialect>();
context.getOrLoadDialect<mlir::math::MathDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
context.getOrLoadDialect<mlir::polygeist::PolygeistDialect>();
// MLIRContext context;