Use LowerTo pragma to generate a linalg.copyOp
- Refactor pragma lower to handling
This commit is contained in:
parent
3b54e6eea4
commit
e9b551c1e7
|
@ -6,6 +6,7 @@
|
|||
#include "clang/Lex/LexDiagnostic.h"
|
||||
#include "clang/Lex/LiteralSupport.h"
|
||||
#include "clang/Lex/Preprocessor.h"
|
||||
#include "clang/Parse/ParseDiagnostic.h"
|
||||
#include "clang/Sema/Sema.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
@ -31,34 +32,70 @@ public:
|
|||
/// the function decl that lower_to is attached to with that MLIR function
|
||||
/// symbol in the class-referenced dictionary.
|
||||
///
|
||||
/// TODO: Handle assertions properly.
|
||||
void HandlePragma(Preprocessor &PP, PragmaIntroducer Introducer,
|
||||
Token &PragmaTok) {
|
||||
Token Tok;
|
||||
PP.Lex(Tok); // lparen
|
||||
assert(Tok.is(tok::l_paren) && "lower_to should start with '('.");
|
||||
if (Tok.isNot(tok::l_paren)) {
|
||||
PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_lparen)
|
||||
<< "lower_to";
|
||||
return;
|
||||
}
|
||||
|
||||
Token FuncIdTok; // function identifier
|
||||
PP.Lex(FuncIdTok);
|
||||
assert(FuncIdTok.is(tok::identifier) &&
|
||||
"The first argument of lower_to should be an identifier.");
|
||||
Token PrevTok = Tok;
|
||||
llvm::StringRef FuncId = llvm::StringRef();
|
||||
llvm::StringRef SymbolName = llvm::StringRef();
|
||||
while (Tok.isNot(tok::eod)) {
|
||||
Token CurrentTok;
|
||||
PP.Lex(CurrentTok);
|
||||
|
||||
llvm::StringRef FuncId = FuncIdTok.getIdentifierInfo()->getName();
|
||||
// rparen.
|
||||
if (PrevTok.is(tok::string_literal)) {
|
||||
if (CurrentTok.isNot(tok::r_paren)) {
|
||||
PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_rparen)
|
||||
<< "lower_to";
|
||||
return;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
PP.Lex(Tok); // comma
|
||||
assert(Tok.is(tok::comma) && "The first and second argument of lower_to "
|
||||
"should be separated by a comma.");
|
||||
// function identifier.
|
||||
if (PrevTok.is(tok::l_paren)) {
|
||||
if (CurrentTok.isNot(tok::identifier)) {
|
||||
PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_identifier)
|
||||
<< "lower_to";
|
||||
return;
|
||||
} else {
|
||||
FuncId = CurrentTok.getIdentifierInfo()->getName();
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the string literal argument, which is the MLIR function symbol.
|
||||
// comma.
|
||||
if (PrevTok.is(tok::identifier)) {
|
||||
if (CurrentTok.isNot(tok::comma)) {
|
||||
PP.Diag(Tok.getLocation(), diag::warn_pragma_expected_punc)
|
||||
<< "lower_to";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// string literal, which is the MLIR function symbol.
|
||||
if (PrevTok.is(tok::comma)) {
|
||||
if (CurrentTok.isNot(tok::string_literal)) {
|
||||
PP.Diag(CurrentTok.getLocation(),
|
||||
diag::warn_pragma_expected_section_name)
|
||||
<< "lower to";
|
||||
return;
|
||||
} else {
|
||||
SmallVector<Token, 1> SymbolToks;
|
||||
Token SymbolTok;
|
||||
PP.Lex(SymbolTok);
|
||||
assert(SymbolTok.is(tok::string_literal) &&
|
||||
"The second argument of lower_to should be a string literal.");
|
||||
SymbolToks.push_back(SymbolTok);
|
||||
StringRef SymbolName = StringLiteralParser(SymbolToks, PP).GetString();
|
||||
PP.Lex(Tok); // rparen
|
||||
assert(Tok.is(tok::r_paren) && "lower_to should end with '('.");
|
||||
SymbolToks.push_back(CurrentTok);
|
||||
SymbolName = StringLiteralParser(SymbolToks, PP).GetString();
|
||||
}
|
||||
}
|
||||
|
||||
PrevTok = CurrentTok;
|
||||
}
|
||||
|
||||
// Link SymbolName with the function.
|
||||
auto result = Info.SymbolTable.try_emplace(FuncId, SymbolName);
|
||||
|
|
|
@ -9,9 +9,25 @@
|
|||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace llvm;
|
||||
|
||||
Operation *mlirclang::buildLinalgOp(const AbstractOperation *op,
|
||||
ArrayRef<Value> operands,
|
||||
ArrayRef<Type> opResultTypes,
|
||||
OpBuilder &b) {
|
||||
StringRef name = op->name;
|
||||
if (name.compare("linalg.copy") == 0) {
|
||||
return b.create<linalg::CopyOp>(b.getUnknownLoc(), operands[1],
|
||||
operands[0]);
|
||||
} else {
|
||||
llvm::report_fatal_error(llvm::Twine("builder not supported for: ") + name);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Operation *mlirclang::replaceFuncByOperation(FuncOp f, StringRef opName,
|
||||
ArrayRef<Value> operands,
|
||||
OpBuilder &b) {
|
||||
|
@ -21,6 +37,9 @@ Operation *mlirclang::replaceFuncByOperation(FuncOp f, StringRef opName,
|
|||
|
||||
const AbstractOperation *op = AbstractOperation::lookup(opName, ctx);
|
||||
|
||||
if (opName.startswith("linalg"))
|
||||
return buildLinalgOp(op, operands, f.getCallableResults(), b);
|
||||
|
||||
// NOTE: The attributes of the provided FuncOp is ignored.
|
||||
OperationState opState(b.getUnknownLoc(), op->name, ValueRange(operands),
|
||||
f.getCallableResults(), {});
|
||||
|
|
|
@ -10,6 +10,8 @@ class Operation;
|
|||
class FuncOp;
|
||||
class Value;
|
||||
class OpBuilder;
|
||||
class AbstractOperation;
|
||||
class Type;
|
||||
} // namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
|
@ -26,6 +28,10 @@ namespace mlirclang {
|
|||
mlir::Operation *replaceFuncByOperation(mlir::FuncOp f, llvm::StringRef opName,
|
||||
llvm::ArrayRef<mlir::Value> operands,
|
||||
mlir::OpBuilder &b);
|
||||
mlir::Operation *buildLinalgOp(const mlir::AbstractOperation *op,
|
||||
llvm::ArrayRef<mlir::Value> operands,
|
||||
llvm::ArrayRef<mlir::Type> results,
|
||||
mlir::OpBuilder &b);
|
||||
|
||||
} // namespace mlirclang
|
||||
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
// RUN: mlir-clang %s | FileCheck %s
|
||||
|
||||
#pragma lower_to(copy_op, "linalg.copy")
|
||||
void copy_op(int a[3][3], int b[3][3]) {
|
||||
for (int i = 0; i < 3; i++)
|
||||
for (int j = 0; j < 3; j++)
|
||||
b[i][j] = a[i][j];
|
||||
}
|
||||
|
||||
int main() {
|
||||
int a[3][3];
|
||||
int b[3][3];
|
||||
// CHECK: linalg.copy({{.*}}, {{.*}}) : memref<3x3xi32>, memref<3x3xi32>
|
||||
copy_op(a, b);
|
||||
return 0;
|
||||
}
|
|
@ -4,6 +4,7 @@
|
|||
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
|
||||
#include "mlir/Dialect/Affine/Passes.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
#include "mlir/Dialect/SCF/Passes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
@ -129,6 +130,7 @@ int main(int argc, char **argv) {
|
|||
context.getOrLoadDialect<mlir::omp::OpenMPDialect>();
|
||||
context.getOrLoadDialect<mlir::math::MathDialect>();
|
||||
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
|
||||
context.getOrLoadDialect<mlir::polygeist::PolygeistDialect>();
|
||||
// MLIRContext context;
|
||||
|
||||
|
|
Loading…
Reference in New Issue