[PDLL] Add a `rewrite` statement to enable complex rewrites

The `rewrite` statement allows for rewriting a given root
operation with a block of nested rewriters. The root operation is
not implicitly erased or replaced, and any transformations to it
must be expressed within the nested rewrite block. The inner body
may contain any number of other rewrite statements, variables, or
expressions.

Differential Revision: https://reviews.llvm.org/D115299
This commit is contained in:
River Riddle 2021-12-16 01:50:03 +00:00
parent 12eebb8e37
commit 3ee44cb775
6 changed files with 142 additions and 1 deletions

View File

@ -278,6 +278,28 @@ private:
friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
};
//===----------------------------------------------------------------------===//
// RewriteStmt
/// This statement represents an operation rewrite that contains a block of
/// nested rewrite commands. This allows for building more complex operation
/// rewrites that span across multiple statements, which may be unconnected.
class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
public:
static RewriteStmt *create(Context &ctx, llvm::SMRange loc, Expr *rootOp,
CompoundStmt *rewriteBody);
/// Return the compound rewrite body.
CompoundStmt *getRewriteBody() const { return rewriteBody; }
private:
RewriteStmt(llvm::SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
: Base(loc, rootOp), rewriteBody(rewriteBody) {}
/// The body of nested rewriters within this statement.
CompoundStmt *rewriteBody;
};
//===----------------------------------------------------------------------===//
// Expr
//===----------------------------------------------------------------------===//
@ -909,7 +931,7 @@ inline bool Expr::classof(const Node *node) {
}
inline bool OpRewriteStmt::classof(const Node *node) {
return isa<EraseStmt, ReplaceStmt>(node);
return isa<EraseStmt, ReplaceStmt, RewriteStmt>(node);
}
inline bool Stmt::classof(const Node *node) {

View File

@ -76,6 +76,7 @@ private:
void printImpl(const EraseStmt *stmt);
void printImpl(const LetStmt *stmt);
void printImpl(const ReplaceStmt *stmt);
void printImpl(const RewriteStmt *stmt);
void printImpl(const AttributeExpr *expr);
void printImpl(const DeclRefExpr *expr);
@ -159,6 +160,7 @@ void NodePrinter::print(const Node *node) {
.Case<
// Statements.
const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
const RewriteStmt,
// Expressions.
const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
@ -197,6 +199,11 @@ void NodePrinter::printImpl(const ReplaceStmt *stmt) {
printChildren("ReplValues", stmt->getReplExprs());
}
void NodePrinter::printImpl(const RewriteStmt *stmt) {
os << "RewriteStmt " << stmt << "\n";
printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
}
void NodePrinter::printImpl(const AttributeExpr *expr) {
os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
}

View File

@ -99,6 +99,15 @@ ReplaceStmt *ReplaceStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp,
return stmt;
}
//===----------------------------------------------------------------------===//
// RewriteStmt
RewriteStmt *RewriteStmt::create(Context &ctx, llvm::SMRange loc, Expr *rootOp,
CompoundStmt *rewriteBody) {
return new (ctx.getAllocator().Allocate<RewriteStmt>())
RewriteStmt(loc, rootOp, rewriteBody);
}
//===----------------------------------------------------------------------===//
// AttributeExpr
//===----------------------------------------------------------------------===//

View File

@ -164,6 +164,7 @@ private:
FailureOr<ast::EraseStmt *> parseEraseStmt();
FailureOr<ast::LetStmt *> parseLetStmt();
FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
FailureOr<ast::RewriteStmt *> parseRewriteStmt();
//===--------------------------------------------------------------------===//
// Creation+Analysis
@ -246,6 +247,9 @@ private:
FailureOr<ast::ReplaceStmt *>
createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
MutableArrayRef<ast::Expr *> replValues);
FailureOr<ast::RewriteStmt *>
createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
ast::CompoundStmt *rewriteBody);
//===--------------------------------------------------------------------===//
// Lexer Utilities
@ -1156,6 +1160,9 @@ FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
case Token::kw_replace:
stmt = parseReplaceStmt();
break;
case Token::kw_rewrite:
stmt = parseRewriteStmt();
break;
default:
stmt = parseExpr();
break;
@ -1307,6 +1314,32 @@ FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
return createReplaceStmt(loc, *rootOp, replValues);
}
FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
llvm::SMRange loc = curToken.getLoc();
consumeToken(Token::kw_rewrite);
// Parse the root operation.
FailureOr<ast::Expr *> rootOp = parseExpr();
if (failed(rootOp))
return failure();
if (failed(parseToken(Token::kw_with, "expected `with` before rewrite body")))
return failure();
if (curToken.isNot(Token::l_brace))
return emitError("expected `{` to start rewrite body");
// The rewrite body of this statement is within a rewrite context.
llvm::SaveAndRestore<ParserContext> saveCtx(parserContext,
ParserContext::Rewrite);
FailureOr<ast::CompoundStmt *> rewriteBody = parseCompoundStmt();
if (failed(rewriteBody))
return failure();
return createRewriteStmt(loc, *rootOp, *rewriteBody);
}
//===----------------------------------------------------------------------===//
// Creation+Analysis
//===----------------------------------------------------------------------===//
@ -1647,6 +1680,20 @@ Parser::createReplaceStmt(llvm::SMRange loc, ast::Expr *rootOp,
return ast::ReplaceStmt::create(ctx, loc, rootOp, replValues);
}
FailureOr<ast::RewriteStmt *>
Parser::createRewriteStmt(llvm::SMRange loc, ast::Expr *rootOp,
ast::CompoundStmt *rewriteBody) {
// Check that root is an Operation.
ast::Type rootType = rootOp->getType();
if (!rootType.isa<ast::OperationType>()) {
return emitError(
rootOp->getLoc(),
llvm::formatv("expected `Op` expression, but got `{0}`", rootType));
}
return ast::RewriteStmt::create(ctx, loc, rootOp, rewriteBody);
}
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//

View File

@ -273,3 +273,37 @@ Pattern {
// CHECK: expected dialect namespace
replace op<>(input: Value) with op<>;
}
// -----
//===----------------------------------------------------------------------===//
// `rewrite`
//===----------------------------------------------------------------------===//
Pattern {
// CHECK: expected `Op` expression
rewrite attr<""> with { op<toy.reshape>; };
}
// -----
Pattern {
// CHECK: expected `with` before rewrite body
rewrite op<>;
}
// -----
Pattern {
// CHECK: expected `{` to start rewrite body
rewrite op<> with;
}
// -----
Pattern {
// CHECK: expected dialect namespace
rewrite root: Op with {
op<>;
};
}

View File

@ -182,3 +182,25 @@ Pattern {
Pattern {
replace _: Op with (_: Value, _: ValueRange, op<my_dialect.foo>);
}
// -----
//===----------------------------------------------------------------------===//
// RewriteStmt
//===----------------------------------------------------------------------===//
// CHECK: Module
// CHECK: `-RewriteStmt
// CHECK: |-DeclRefExpr {{.*}} Type<Op>
// CHECK: `-CompoundStmt
// CHECK: |-OperationExpr {{.*}} Type<Op<my_dialect.some_op>>
// CHECK: `-ReplaceStmt {{.*}}
// CHECK: `-DeclRefExpr {{.*}} Type<Op>
// CHECK: `ReplValues`
// CHECK: `-OperationExpr {{.*}} Type<Op<my_dialect.foo>>
Pattern {
rewrite root: Op with {
op<my_dialect.some_op>;
replace root with op<my_dialect.foo>;
};
}