mirror of https://github.com/llvm/circt.git
[python] Add `walk_with_filter` to walk subset of IR (#7591)
This adds `walk_with_filter` python method to invoke callbacks only for subset of operations. We require GIL to call python function so this could improve performance of the tools based on Python API. ``` Starting walk_with_filter walk_with_filter elapsed time: 0.005462 seconds cnt=1 Starting operation.walk operation.walk elapsed time: 1.061360 seconds cnt=2 ```
This commit is contained in:
parent
6638aafebc
commit
c433d61523
|
@ -0,0 +1,76 @@
|
|||
# REQUIRES: bindings_python
|
||||
# RUN: %PYTHON% %s | FileCheck %s
|
||||
|
||||
import circt
|
||||
from circt.support import walk_with_filter
|
||||
from circt.dialects import hw
|
||||
from circt.ir import Context, Module, WalkOrder, WalkResult
|
||||
|
||||
|
||||
def test_walk_with_filter():
|
||||
ctx = Context()
|
||||
circt.register_dialects(ctx)
|
||||
module = Module.parse(
|
||||
r"""
|
||||
builtin.module {
|
||||
hw.module @f() {
|
||||
hw.output
|
||||
}
|
||||
}
|
||||
""",
|
||||
ctx,
|
||||
)
|
||||
|
||||
def callback(op):
|
||||
print(op.name)
|
||||
return WalkResult.ADVANCE
|
||||
|
||||
# Test post-order walk.
|
||||
# CHECK: Post-order
|
||||
# CHECK-NEXT: hw.output
|
||||
# CHECK-NEXT: hw.module
|
||||
# CHECK-NOT: builtin.module
|
||||
print("Post-order")
|
||||
walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback,
|
||||
WalkOrder.POST_ORDER)
|
||||
|
||||
# Test pre-order walk.
|
||||
# CHECK-NEXT: Pre-order
|
||||
# CHECK-NOT: builtin.module
|
||||
# CHECK-NEXT: hw.module
|
||||
# CHECK-NEXT: hw.output
|
||||
print("Pre-order")
|
||||
walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback,
|
||||
WalkOrder.PRE_ORDER)
|
||||
|
||||
# Test interrupt.
|
||||
# CHECK-NEXT: Interrupt post-order
|
||||
# CHECK-NEXT: hw.output
|
||||
print("Interrupt post-order")
|
||||
|
||||
def interrupt_callback(op):
|
||||
print(op.name)
|
||||
return WalkResult.INTERRUPT
|
||||
|
||||
walk_with_filter(module.operation, [hw.OutputOp], interrupt_callback,
|
||||
WalkOrder.POST_ORDER)
|
||||
|
||||
# Test exception.
|
||||
# CHECK: Exception
|
||||
# CHECK-NEXT: hw.output
|
||||
# CHECK-NEXT: Exception raised
|
||||
print("Exception")
|
||||
|
||||
def exception_callback(op):
|
||||
print(op.name)
|
||||
raise ValueError
|
||||
return WalkResult.ADVANCE
|
||||
|
||||
try:
|
||||
walk_with_filter(module.operation, [hw.OutputOp], exception_callback,
|
||||
WalkOrder.POST_ORDER)
|
||||
except RuntimeError:
|
||||
print("Exception raised")
|
||||
|
||||
|
||||
test_walk_with_filter()
|
|
@ -6,7 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "DialectModules.h"
|
||||
#include "CIRCTModules.h"
|
||||
|
||||
#include "circt-c/Conversion.h"
|
||||
#include "circt-c/Dialect/Comb.h"
|
||||
|
@ -142,4 +142,6 @@ PYBIND11_MODULE(_circt, m) {
|
|||
circt::python::populateDialectOMSubmodule(om);
|
||||
py::module sv = m.def_submodule("_sv", "SV API");
|
||||
circt::python::populateDialectSVSubmodule(sv);
|
||||
py::module support = m.def_submodule("_support", "CIRCT support");
|
||||
circt::python::populateSupportSubmodule(support);
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- DialectModules.h - Populate submodules -----------------------------===//
|
||||
//===- CIRCTModules.h - Populate submodules -------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -6,12 +6,12 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Functions to populate each dialect's submodule (if provided).
|
||||
// Functions to populate submodules in CIRCT (if provided).
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CIRCT_BINDINGS_PYTHON_DIALECTMODULES_H
|
||||
#define CIRCT_BINDINGS_PYTHON_DIALECTMODULES_H
|
||||
#ifndef CIRCT_BINDINGS_PYTHON_CIRCTMODULES_H
|
||||
#define CIRCT_BINDINGS_PYTHON_CIRCTMODULES_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
|
@ -24,6 +24,7 @@ void populateDialectMSFTSubmodule(pybind11::module &m);
|
|||
void populateDialectOMSubmodule(pybind11::module &m);
|
||||
void populateDialectSeqSubmodule(pybind11::module &m);
|
||||
void populateDialectSVSubmodule(pybind11::module &m);
|
||||
void populateSupportSubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace circt
|
|
@ -20,6 +20,7 @@ declare_mlir_python_extension(CIRCTBindingsPythonExtension
|
|||
OMModule.cpp
|
||||
MSFTModule.cpp
|
||||
SeqModule.cpp
|
||||
SupportModule.cpp
|
||||
SVModule.cpp
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
CIRCTCAPIComb
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "DialectModules.h"
|
||||
#include "CIRCTModules.h"
|
||||
|
||||
#include "circt/Dialect/ESI/ESIDialect.h"
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "DialectModules.h"
|
||||
#include "CIRCTModules.h"
|
||||
|
||||
#include "circt-c/Dialect/HW.h"
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "DialectModules.h"
|
||||
#include "CIRCTModules.h"
|
||||
|
||||
#include "circt-c/Dialect/MSFT.h"
|
||||
#include "circt/Dialect/MSFT/MSFTDialect.h"
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "DialectModules.h"
|
||||
#include "CIRCTModules.h"
|
||||
#include "circt-c/Dialect/HW.h"
|
||||
#include "circt-c/Dialect/OM.h"
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "DialectModules.h"
|
||||
#include "CIRCTModules.h"
|
||||
|
||||
#include "circt-c/Dialect/SV.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "DialectModules.h"
|
||||
#include "CIRCTModules.h"
|
||||
|
||||
#include "circt-c/Dialect/Seq.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
//===- SupportModule.cpp - Support API pybind module ----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "CIRCTModules.h"
|
||||
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
|
||||
#include "PybindUtils.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using namespace circt;
|
||||
using namespace mlir::python::adaptors;
|
||||
|
||||
/// Populate the support python module.
|
||||
void circt::python::populateSupportSubmodule(py::module &m) {
|
||||
m.doc() = "CIRCT Python utils";
|
||||
// Walk with filter.
|
||||
m.def(
|
||||
"_walk_with_filter",
|
||||
[](MlirOperation operation, std::vector<std::string> op_names,
|
||||
std::function<MlirWalkResult(MlirOperation)> callback,
|
||||
MlirWalkOrder walkOrder) {
|
||||
struct UserData {
|
||||
std::function<MlirWalkResult(MlirOperation)> callback;
|
||||
bool gotException;
|
||||
std::string exceptionWhat;
|
||||
py::object exceptionType;
|
||||
std::vector<MlirIdentifier> op_names;
|
||||
};
|
||||
|
||||
std::vector<MlirIdentifier> op_names_identifiers;
|
||||
|
||||
// Construct MlirIdentifier from string to perform pointer comparison.
|
||||
for (auto &op_name : op_names)
|
||||
op_names_identifiers.push_back(mlirIdentifierGet(
|
||||
mlirOperationGetContext(operation),
|
||||
mlirStringRefCreateFromCString(op_name.c_str())));
|
||||
|
||||
UserData userData{callback, false, {}, {}, op_names_identifiers};
|
||||
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
|
||||
void *userData) {
|
||||
UserData *calleeUserData = static_cast<UserData *>(userData);
|
||||
auto op_name = mlirOperationGetName(op);
|
||||
|
||||
// Check if the operation name is in the filter.
|
||||
bool inFilter = false;
|
||||
for (auto &op_name_identifier : calleeUserData->op_names) {
|
||||
if (mlirIdentifierEqual(op_name, op_name_identifier)) {
|
||||
inFilter = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If the operation name is not in the filter, skip it.
|
||||
if (!inFilter)
|
||||
return MlirWalkResult::MlirWalkResultAdvance;
|
||||
|
||||
try {
|
||||
return (calleeUserData->callback)(op);
|
||||
} catch (py::error_already_set &e) {
|
||||
calleeUserData->gotException = true;
|
||||
calleeUserData->exceptionWhat = e.what();
|
||||
calleeUserData->exceptionType = e.type();
|
||||
return MlirWalkResult::MlirWalkResultInterrupt;
|
||||
}
|
||||
};
|
||||
mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
|
||||
if (userData.gotException) {
|
||||
std::string message("Exception raised in callback: ");
|
||||
message.append(userData.exceptionWhat);
|
||||
throw std::runtime_error(message);
|
||||
}
|
||||
},
|
||||
py::arg("op"), py::arg("op_names"), py::arg("callback"),
|
||||
py::arg("walk_order"));
|
||||
}
|
|
@ -4,6 +4,8 @@
|
|||
|
||||
from . import ir
|
||||
|
||||
from ._mlir_libs._circt._support import _walk_with_filter
|
||||
from .ir import Operation
|
||||
from contextlib import AbstractContextManager
|
||||
from contextvars import ContextVar
|
||||
from typing import List
|
||||
|
@ -409,3 +411,13 @@ class NamedValueOpView:
|
|||
def operation(self):
|
||||
"""Get the operation associated with this builder."""
|
||||
return self.opview.operation
|
||||
|
||||
|
||||
# Helper function to walk operation with a filter on operation names.
|
||||
# `op_views` is a list of operation views to visit. This is a wrapper
|
||||
# around the C++ implementation of walk_with_filter.
|
||||
def walk_with_filter(operation: Operation, op_views: List[ir.OpView], callback,
|
||||
walk_order):
|
||||
op_names_identifiers = [name.OPERATION_NAME for name in op_views]
|
||||
return _walk_with_filter(operation, op_names_identifiers, callback,
|
||||
walk_order)
|
||||
|
|
Loading…
Reference in New Issue