mirror of https://github.com/Qiskit/qiskit.git
Fully port FilterOpNodes to Rust (#13052)
* Fully port FilterOpNodes to Rust This commit ports the FilterOpNodes pass to rust. This pass is exceedingly simple and just runs a filter function over all the op nodes and removes nodes that match the filter. However, the API for the class exposes that filter function interface as a user provided Python callable. So for the current pass we need to retain that python callback. This limits the absolute performance of this pass because we're bottlenecked by calling python. Looking to the future, this commit adds a rust native method to DAGCircuit to perform this filtering with a rust predicate FnMut. It isn't leveraged by the python implementation because of layer mismatch for the efficient rust interface and Python working with `DAGOpNode` objects. A function using that interface is added to filter labeled nodes. In the preset pass manager we only use FilterOpNodes to remove nodes with a specific label (which is used to identify temporary barriers created by qiskit). In a follow up we should consider leveraging this new function to build a new pass specifically for this use case. Fixes #12263 Part of #12208 * Make filter_op_nodes() infallible The filter_op_nodes() method originally returned a Result<()> to handle a predicate that was fallible. This was because the original intent for the method was to use it with Python callbacks in the predicate. But because of differences between the rust API and the Python API this wasn't feasible as was originally planned. So this Result<()> return wasn't used anymore. This commit reworks it to make the filter_op_nodes() infallible and the predicate a user provides also only returns `bool` and not `Result<bool>`. * Rename filter_labelled_op to filter_labeled_op
This commit is contained in:
parent
254ba83dc6
commit
8982dbde15
|
@ -0,0 +1,63 @@
|
|||
// This code is part of Qiskit.
|
||||
//
|
||||
// (C) Copyright IBM 2024
|
||||
//
|
||||
// This code is licensed under the Apache License, Version 2.0. You may
|
||||
// obtain a copy of this license in the LICENSE.txt file in the root directory
|
||||
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
|
||||
//
|
||||
// Any modifications or derivative works of this code must retain this
|
||||
// copyright notice, and modified files need to carry a notice indicating
|
||||
// that they have been altered from the originals.
|
||||
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
|
||||
use qiskit_circuit::dag_circuit::DAGCircuit;
|
||||
use qiskit_circuit::packed_instruction::PackedInstruction;
|
||||
use rustworkx_core::petgraph::stable_graph::NodeIndex;
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(name = "filter_op_nodes")]
|
||||
pub fn py_filter_op_nodes(
|
||||
py: Python,
|
||||
dag: &mut DAGCircuit,
|
||||
predicate: &Bound<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
let callable = |node: NodeIndex| -> PyResult<bool> {
|
||||
let dag_op_node = dag.get_node(py, node)?;
|
||||
predicate.call1((dag_op_node,))?.extract()
|
||||
};
|
||||
let mut remove_nodes: Vec<NodeIndex> = Vec::new();
|
||||
for node in dag.op_nodes(true) {
|
||||
if !callable(node)? {
|
||||
remove_nodes.push(node);
|
||||
}
|
||||
}
|
||||
for node in remove_nodes {
|
||||
dag.remove_op_node(node);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove any nodes that have the provided label set
|
||||
///
|
||||
/// Args:
|
||||
/// dag (DAGCircuit): The dag circuit to filter the ops from
|
||||
/// label (str): The label to filter nodes on
|
||||
#[pyfunction]
|
||||
pub fn filter_labeled_op(dag: &mut DAGCircuit, label: String) {
|
||||
let predicate = |node: &PackedInstruction| -> bool {
|
||||
match node.label() {
|
||||
Some(inst_label) => inst_label != label,
|
||||
None => false,
|
||||
}
|
||||
};
|
||||
dag.filter_op_nodes(predicate);
|
||||
}
|
||||
|
||||
pub fn filter_op_nodes_mod(m: &Bound<PyModule>) -> PyResult<()> {
|
||||
m.add_wrapped(wrap_pyfunction!(py_filter_op_nodes))?;
|
||||
m.add_wrapped(wrap_pyfunction!(filter_labeled_op))?;
|
||||
Ok(())
|
||||
}
|
|
@ -22,6 +22,7 @@ pub mod dense_layout;
|
|||
pub mod edge_collections;
|
||||
pub mod error_map;
|
||||
pub mod euler_one_qubit_decomposer;
|
||||
pub mod filter_op_nodes;
|
||||
pub mod isometry;
|
||||
pub mod nlayout;
|
||||
pub mod optimize_1q_gates;
|
||||
|
|
|
@ -49,7 +49,9 @@ fn run_remove_diagonal_before_measure(dag: &mut DAGCircuit) -> PyResult<()> {
|
|||
let mut nodes_to_remove = Vec::new();
|
||||
for index in dag.op_nodes(true) {
|
||||
let node = &dag.dag[index];
|
||||
let NodeType::Operation(inst) = node else {panic!()};
|
||||
let NodeType::Operation(inst) = node else {
|
||||
panic!()
|
||||
};
|
||||
|
||||
if inst.op.name() == "measure" {
|
||||
let predecessor = (dag.quantum_predecessors(index))
|
||||
|
|
|
@ -5794,6 +5794,25 @@ impl DAGCircuit {
|
|||
}
|
||||
}
|
||||
|
||||
// Filter any nodes that don't match a given predicate function
|
||||
pub fn filter_op_nodes<F>(&mut self, mut predicate: F)
|
||||
where
|
||||
F: FnMut(&PackedInstruction) -> bool,
|
||||
{
|
||||
let mut remove_nodes: Vec<NodeIndex> = Vec::new();
|
||||
for node in self.op_nodes(true) {
|
||||
let NodeType::Operation(op) = &self.dag[node] else {
|
||||
unreachable!()
|
||||
};
|
||||
if !predicate(op) {
|
||||
remove_nodes.push(node);
|
||||
}
|
||||
}
|
||||
for node in remove_nodes {
|
||||
self.remove_op_node(node);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn op_nodes_by_py_type<'a>(
|
||||
&'a self,
|
||||
op: &'a Bound<PyType>,
|
||||
|
|
|
@ -553,6 +553,13 @@ impl PackedInstruction {
|
|||
.and_then(|extra| extra.condition.as_ref())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn label(&self) -> Option<&str> {
|
||||
self.extra_attrs
|
||||
.as_ref()
|
||||
.and_then(|extra| extra.label.as_deref())
|
||||
}
|
||||
|
||||
/// Build a reference to the Python-space operation object (the `Gate`, etc) packed into this
|
||||
/// instruction. This may construct the reference if the `PackedInstruction` is a standard
|
||||
/// gate with no already stored operation.
|
||||
|
|
|
@ -16,8 +16,9 @@ use qiskit_accelerate::{
|
|||
circuit_library::circuit_library, commutation_analysis::commutation_analysis,
|
||||
commutation_checker::commutation_checker, convert_2q_block_matrix::convert_2q_block_matrix,
|
||||
dense_layout::dense_layout, error_map::error_map,
|
||||
euler_one_qubit_decomposer::euler_one_qubit_decomposer, isometry::isometry, nlayout::nlayout,
|
||||
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval,
|
||||
euler_one_qubit_decomposer::euler_one_qubit_decomposer, filter_op_nodes::filter_op_nodes_mod,
|
||||
isometry::isometry, nlayout::nlayout, optimize_1q_gates::optimize_1q_gates,
|
||||
pauli_exp_val::pauli_expval,
|
||||
remove_diagonal_gates_before_measure::remove_diagonal_gates_before_measure, results::results,
|
||||
sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
|
||||
star_prerouting::star_prerouting, stochastic_swap::stochastic_swap, synthesis::synthesis,
|
||||
|
@ -46,6 +47,7 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
|
|||
add_submodule(m, dense_layout, "dense_layout")?;
|
||||
add_submodule(m, error_map, "error_map")?;
|
||||
add_submodule(m, euler_one_qubit_decomposer, "euler_one_qubit_decomposer")?;
|
||||
add_submodule(m, filter_op_nodes_mod, "filter_op_nodes")?;
|
||||
add_submodule(m, isometry, "isometry")?;
|
||||
add_submodule(m, nlayout, "nlayout")?;
|
||||
add_submodule(m, optimize_1q_gates, "optimize_1q_gates")?;
|
||||
|
|
|
@ -92,6 +92,7 @@ sys.modules["qiskit._accelerate.synthesis.clifford"] = _accelerate.synthesis.cli
|
|||
sys.modules["qiskit._accelerate.commutation_checker"] = _accelerate.commutation_checker
|
||||
sys.modules["qiskit._accelerate.commutation_analysis"] = _accelerate.commutation_analysis
|
||||
sys.modules["qiskit._accelerate.synthesis.linear_phase"] = _accelerate.synthesis.linear_phase
|
||||
sys.modules["qiskit._accelerate.filter_op_nodes"] = _accelerate.filter_op_nodes
|
||||
|
||||
from qiskit.exceptions import QiskitError, MissingOptionalLibraryError
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@ from qiskit.dagcircuit import DAGCircuit, DAGOpNode
|
|||
from qiskit.transpiler.basepasses import TransformationPass
|
||||
from qiskit.transpiler.passes.utils import control_flow
|
||||
|
||||
from qiskit._accelerate.filter_op_nodes import filter_op_nodes
|
||||
|
||||
|
||||
class FilterOpNodes(TransformationPass):
|
||||
"""Remove all operations that match a filter function
|
||||
|
@ -59,7 +61,5 @@ class FilterOpNodes(TransformationPass):
|
|||
@control_flow.trivial_recurse
|
||||
def run(self, dag: DAGCircuit) -> DAGCircuit:
|
||||
"""Run the RemoveBarriers pass on `dag`."""
|
||||
for node in dag.op_nodes():
|
||||
if not self.predicate(node):
|
||||
dag.remove_op_node(node)
|
||||
filter_op_nodes(dag, self.predicate)
|
||||
return dag
|
||||
|
|
Loading…
Reference in New Issue