[PDLL] Add a `rewrite` statement to enable complex rewrites
The `rewrite` statement allows for rewriting a given root operation with a block of nested rewriters. The root operation is not implicitly erased or replaced, and any transformations to it must be expressed within the nested rewrite block. The inner body may contain any number of other rewrite statements, variables, or expressions. Differential Revision: https://reviews.llvm.org/D115299
This commit is contained in:
parent
12eebb8e37
commit
3ee44cb775
|
@ -278,6 +278,28 @@ private:
|
|||
friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RewriteStmt
|
||||
|
||||
/// This statement represents an operation rewrite that contains a block of
|
||||
/// nested rewrite commands. This allows for building more complex operation
|
||||
/// rewrites that span across multiple statements, which may be unconnected.
|
||||
class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
|
||||
public:
|
||||
static RewriteStmt *create(Context &ctx, llvm::SMRange loc, Expr *rootOp,
|
||||
CompoundStmt *rewriteBody);
|
||||
|
||||
/// Return the compound rewrite body.
|
||||
CompoundStmt *getRewriteBody() const { return rewriteBody; }
|
||||
|
||||
private:
|
||||
RewriteStmt(llvm::SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
|
||||
: Base(loc, rootOp), rewriteBody(rewriteBody) {}
|
||||
|
||||
/// The body of nested rewriters within this statement.
|
||||
CompoundStmt *rewriteBody;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Expr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -909,7 +931,7 @@ inline bool Expr::classof(const Node *node) {
|
|||
}
|
||||
|
||||
inline bool OpRewriteStmt::classof(const Node *node) {
|
||||
return isa<EraseStmt, ReplaceStmt>(node);
|
||||
return isa<EraseStmt, ReplaceStmt, RewriteStmt>(node);
|
||||
}
|
||||
|
||||
inline bool Stmt::classof(const Node *node) {
|
||||
|
|
|
@ -76,6 +76,7 @@ private:
|
|||
void printImpl(const EraseStmt *stmt);
|
||||
void printImpl(const LetStmt *stmt);
|
||||
void printImpl(const ReplaceStmt *stmt);
|
||||
void printImpl(const RewriteStmt *stmt);
|
||||
|
||||
void printImpl(const AttributeExpr *expr);
|
||||
void printImpl(const DeclRefExpr *expr);
|
||||
|
@ -159,6 +160,7 @@ void NodePrinter::print(const Node *node) {
|
|||
.Case<
|
||||
// Statements.
|
||||
const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
|
||||
const RewriteStmt,
|
||||
|
||||
// Expressions.
|
||||
const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
|
||||
|
@ -197,6 +199,11 @@ void NodePrinter::printImpl(const ReplaceStmt *stmt) {
|
|||
printChildren("ReplValues", stmt->getReplExprs());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const RewriteStmt *stmt) {
|
||||
os << "RewriteStmt " << stmt << "\n";
|
||||
printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const AttributeExpr *expr) {
|
||||
os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
|
||||
}
|
||||
|
|
|
@ -99,6 +99,15 @@ ReplaceStmt *ReplaceStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp,
|
|||
return stmt;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RewriteStmt
|
||||
|
||||
RewriteStmt *RewriteStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp,
|
||||
CompoundStmt *rewriteBody) {
|
||||
return new (ctx.getAllocator().Allocate<RewriteStmt>())
|
||||
RewriteStmt(loc, rootOp, rewriteBody);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttributeExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -164,6 +164,7 @@ private:
|
|||
FailureOr<ast::EraseStmt *> parseEraseStmt();
|
||||
FailureOr<ast::LetStmt *> parseLetStmt();
|
||||
FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
|
||||
FailureOr<ast::RewriteStmt *> parseRewriteStmt();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Creation+Analysis
|
||||
|
@ -246,6 +247,9 @@ private:
|
|||
FailureOr<ast::ReplaceStmt *>
|
||||
createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
|
||||
MutableArrayRef<ast::Expr *> replValues);
|
||||
FailureOr<ast::RewriteStmt *>
|
||||
createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
|
||||
ast::CompoundStmt *rewriteBody);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Lexer Utilities
|
||||
|
@ -1156,6 +1160,9 @@ FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
|
|||
case Token::kw_replace:
|
||||
stmt = parseReplaceStmt();
|
||||
break;
|
||||
case Token::kw_rewrite:
|
||||
stmt = parseRewriteStmt();
|
||||
break;
|
||||
default:
|
||||
stmt = parseExpr();
|
||||
break;
|
||||
|
@ -1307,6 +1314,32 @@ FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
|
|||
return createReplaceStmt(loc, *rootOp, replValues);
|
||||
}
|
||||
|
||||
FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
|
||||
llvm::SMRange loc = curToken.getLoc();
|
||||
consumeToken(Token::kw_rewrite);
|
||||
|
||||
// Parse the root operation.
|
||||
FailureOr<ast::Expr *> rootOp = parseExpr();
|
||||
if (failed(rootOp))
|
||||
return failure();
|
||||
|
||||
if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
|
||||
return failure();
|
||||
|
||||
if (curToken.isNot(Token::l_brace))
|
||||
return emitError("expected `{` to start rewrite body");
|
||||
|
||||
// The rewrite body of this statement is within a rewrite context.
|
||||
llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
|
||||
ParserContext::Rewrite);
|
||||
|
||||
FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
|
||||
if (failed(rewriteBody))
|
||||
return failure();
|
||||
|
||||
return createRewriteStmt(loc, *rootOp, *rewriteBody);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Creation+Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1647,6 +1680,20 @@ Parser::createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
|
|||
return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
|
||||
}
|
||||
|
||||
FailureOr<ast::RewriteStmt *>
|
||||
Parser::createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
|
||||
ast::CompoundStmt *rewriteBody) {
|
||||
// Check that root is an Operation.
|
||||
ast::Type rootType = rootOp->getType();
|
||||
if (!rootType.isa<ast::OperationType>()) {
|
||||
return emitError(
|
||||
rootOp->getLoc(),
|
||||
llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
|
||||
}
|
||||
|
||||
return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Parser
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -273,3 +273,37 @@ Pattern {
|
|||
// CHECK: expected dialect namespace
|
||||
replace op<>(input: Value) with op<>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// `rewrite`
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `Op` expression
|
||||
rewrite attr<""> with { op<toy.reshape>; };
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `with` before rewrite body
|
||||
rewrite op<>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `{` to start rewrite body
|
||||
rewrite op<> with;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected dialect namespace
|
||||
rewrite root: Op with {
|
||||
op<>;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -182,3 +182,25 @@ Pattern {
|
|||
Pattern {
|
||||
replace _: Op with (_: Value, _: ValueRange, op<my_dialect.foo>);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RewriteStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-RewriteStmt
|
||||
// CHECK: |-DeclRefExpr {{.*}} Type<Op>
|
||||
// CHECK: `-CompoundStmt
|
||||
// CHECK: |-OperationExpr {{.*}} Type<Op<my_dialect.some_op>>
|
||||
// CHECK: `-ReplaceStmt {{.*}}
|
||||
// CHECK: `-DeclRefExpr {{.*}} Type<Op>
|
||||
// CHECK: `ReplValues`
|
||||
// CHECK: `-OperationExpr {{.*}} Type<Op<my_dialect.foo>>
|
||||
Pattern {
|
||||
rewrite root: Op with {
|
||||
op<my_dialect.some_op>;
|
||||
replace root with op<my_dialect.foo>;
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue