[python] Add `walk_with_filter` to walk subset of IR (#7591)

This adds `walk_with_filter` python method to invoke callbacks
only for subset of operations. We require GIL to call python function
so this could improve performance of the tools based on Python API.

```
Starting walk_with_filter
walk_with_filter elapsed time: 0.005462 seconds cnt=1
Starting operation.walk
operation.walk elapsed time: 1.061360 seconds cnt=2
```
This commit is contained in:
Hideto Ueno 2024-09-10 15:13:59 +09:00 committed by GitHub
parent 6638aafebc
commit c433d61523
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 189 additions and 11 deletions

View File

@ -0,0 +1,76 @@
# REQUIRES: bindings_python
# RUN: %PYTHON% %s | FileCheck %s
import circt
from circt.support import walk_with_filter
from circt.dialects import hw
from circt.ir import Context, Module, WalkOrder, WalkResult
def test_walk_with_filter():
ctx = Context()
circt.register_dialects(ctx)
module = Module.parse(
r"""
builtin.module {
hw.module @f() {
hw.output
}
}
""",
ctx,
)
def callback(op):
print(op.name)
return WalkResult.ADVANCE
# Test post-order walk.
# CHECK: Post-order
# CHECK-NEXT: hw.output
# CHECK-NEXT: hw.module
# CHECK-NOT: builtin.module
print("Post-order")
walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback,
WalkOrder.POST_ORDER)
# Test pre-order walk.
# CHECK-NEXT: Pre-order
# CHECK-NOT: builtin.module
# CHECK-NEXT: hw.module
# CHECK-NEXT: hw.output
print("Pre-order")
walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback,
WalkOrder.PRE_ORDER)
# Test interrupt.
# CHECK-NEXT: Interrupt post-order
# CHECK-NEXT: hw.output
print("Interrupt post-order")
def interrupt_callback(op):
print(op.name)
return WalkResult.INTERRUPT
walk_with_filter(module.operation, [hw.OutputOp], interrupt_callback,
WalkOrder.POST_ORDER)
# Test exception.
# CHECK: Exception
# CHECK-NEXT: hw.output
# CHECK-NEXT: Exception raised
print("Exception")
def exception_callback(op):
print(op.name)
raise ValueError
return WalkResult.ADVANCE
try:
walk_with_filter(module.operation, [hw.OutputOp], exception_callback,
WalkOrder.POST_ORDER)
except RuntimeError:
print("Exception raised")
test_walk_with_filter()

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "DialectModules.h"
#include "CIRCTModules.h"
#include "circt-c/Conversion.h"
#include "circt-c/Dialect/Comb.h"
@ -142,4 +142,6 @@ PYBIND11_MODULE(_circt, m) {
circt::python::populateDialectOMSubmodule(om);
py::module sv = m.def_submodule("_sv", "SV API");
circt::python::populateDialectSVSubmodule(sv);
py::module support = m.def_submodule("_support", "CIRCT support");
circt::python::populateSupportSubmodule(support);
}

View File

@ -1,4 +1,4 @@
//===- DialectModules.h - Populate submodules -----------------------------===//
//===- CIRCTModules.h - Populate submodules -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
//
// Functions to populate each dialect's submodule (if provided).
// Functions to populate submodules in CIRCT (if provided).
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_BINDINGS_PYTHON_DIALECTMODULES_H
#define CIRCT_BINDINGS_PYTHON_DIALECTMODULES_H
#ifndef CIRCT_BINDINGS_PYTHON_CIRCTMODULES_H
#define CIRCT_BINDINGS_PYTHON_CIRCTMODULES_H
#include <pybind11/pybind11.h>
@ -24,6 +24,7 @@ void populateDialectMSFTSubmodule(pybind11::module &m);
void populateDialectOMSubmodule(pybind11::module &m);
void populateDialectSeqSubmodule(pybind11::module &m);
void populateDialectSVSubmodule(pybind11::module &m);
void populateSupportSubmodule(pybind11::module &m);
} // namespace python
} // namespace circt

View File

@ -20,6 +20,7 @@ declare_mlir_python_extension(CIRCTBindingsPythonExtension
OMModule.cpp
MSFTModule.cpp
SeqModule.cpp
SupportModule.cpp
SVModule.cpp
EMBED_CAPI_LINK_LIBS
CIRCTCAPIComb

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "DialectModules.h"
#include "CIRCTModules.h"
#include "circt/Dialect/ESI/ESIDialect.h"

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "DialectModules.h"
#include "CIRCTModules.h"
#include "circt-c/Dialect/HW.h"

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "DialectModules.h"
#include "CIRCTModules.h"
#include "circt-c/Dialect/MSFT.h"
#include "circt/Dialect/MSFT/MSFTDialect.h"

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "DialectModules.h"
#include "CIRCTModules.h"
#include "circt-c/Dialect/HW.h"
#include "circt-c/Dialect/OM.h"
#include "mlir-c/BuiltinAttributes.h"

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "DialectModules.h"
#include "CIRCTModules.h"
#include "circt-c/Dialect/SV.h"
#include "mlir-c/Bindings/Python/Interop.h"

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "DialectModules.h"
#include "CIRCTModules.h"
#include "circt-c/Dialect/Seq.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"

View File

@ -0,0 +1,86 @@
//===- SupportModule.cpp - Support API pybind module ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "CIRCTModules.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "PybindUtils.h"
#include "mlir-c/Support.h"
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
namespace py = pybind11;
using namespace circt;
using namespace mlir::python::adaptors;
/// Populate the support python module.
void circt::python::populateSupportSubmodule(py::module &m) {
m.doc() = "CIRCT Python utils";
// Walk with filter.
m.def(
"_walk_with_filter",
[](MlirOperation operation, std::vector<std::string> op_names,
std::function<MlirWalkResult(MlirOperation)> callback,
MlirWalkOrder walkOrder) {
struct UserData {
std::function<MlirWalkResult(MlirOperation)> callback;
bool gotException;
std::string exceptionWhat;
py::object exceptionType;
std::vector<MlirIdentifier> op_names;
};
std::vector<MlirIdentifier> op_names_identifiers;
// Construct MlirIdentifier from string to perform pointer comparison.
for (auto &op_name : op_names)
op_names_identifiers.push_back(mlirIdentifierGet(
mlirOperationGetContext(operation),
mlirStringRefCreateFromCString(op_name.c_str())));
UserData userData{callback, false, {}, {}, op_names_identifiers};
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
void *userData) {
UserData *calleeUserData = static_cast<UserData *>(userData);
auto op_name = mlirOperationGetName(op);
// Check if the operation name is in the filter.
bool inFilter = false;
for (auto &op_name_identifier : calleeUserData->op_names) {
if (mlirIdentifierEqual(op_name, op_name_identifier)) {
inFilter = true;
break;
}
}
// If the operation name is not in the filter, skip it.
if (!inFilter)
return MlirWalkResult::MlirWalkResultAdvance;
try {
return (calleeUserData->callback)(op);
} catch (py::error_already_set &e) {
calleeUserData->gotException = true;
calleeUserData->exceptionWhat = e.what();
calleeUserData->exceptionType = e.type();
return MlirWalkResult::MlirWalkResultInterrupt;
}
};
mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
if (userData.gotException) {
std::string message("Exception raised in callback: ");
message.append(userData.exceptionWhat);
throw std::runtime_error(message);
}
},
py::arg("op"), py::arg("op_names"), py::arg("callback"),
py::arg("walk_order"));
}

View File

@ -4,6 +4,8 @@
from . import ir
from ._mlir_libs._circt._support import _walk_with_filter
from .ir import Operation
from contextlib import AbstractContextManager
from contextvars import ContextVar
from typing import List
@ -409,3 +411,13 @@ class NamedValueOpView:
def operation(self):
"""Get the operation associated with this builder."""
return self.opview.operation
# Helper function to walk operation with a filter on operation names.
# `op_views` is a list of operation views to visit. This is a wrapper
# around the C++ implementation of walk_with_filter.
def walk_with_filter(operation: Operation, op_views: List[ir.OpView], callback,
walk_order):
op_names_identifiers = [name.OPERATION_NAME for name in op_views]
return _walk_with_filter(operation, op_names_identifiers, callback,
walk_order)