Avoid Python op creation in BasisTranslator (#12705)

This commit updates the BasisTranslator transpiler pass. It builds off
of #12692 and #12701 to adjust access patterns in the python transpiler
path to avoid eagerly creating a Python space operation object. The goal
of this PR is to mitigate the performance regression introduced by the
extra conversion cost of #12459 on the BasisTranslator.
This commit is contained in:
Matthew Treinish 2024-07-10 09:48:08 -04:00 committed by GitHub
parent fa3d6df04f
commit 1e8205e43d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 26 deletions

View File

@ -270,6 +270,11 @@ impl DAGOpNode {
self.instruction.params.to_object(py)
}
#[setter]
fn set_params(&mut self, val: smallvec::SmallVec<[crate::operations::Param; 3]>) {
self.instruction.params = val;
}
pub fn is_parameterized(&self) -> bool {
self.instruction.is_parameterized()
}

View File

@ -30,11 +30,13 @@ from qiskit.circuit import (
QuantumCircuit,
ParameterExpression,
)
from qiskit.dagcircuit import DAGCircuit
from qiskit.dagcircuit import DAGCircuit, DAGOpNode
from qiskit.converters import circuit_to_dag, dag_to_circuit
from qiskit.circuit.equivalence import Key, NodeData
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.exceptions import TranspilerError
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES
from qiskit._accelerate.circuit import StandardGate
logger = logging.getLogger(__name__)
@ -253,7 +255,7 @@ class BasisTranslator(TransformationPass):
node_qargs = tuple(wire_map[bit] for bit in node.qargs)
qubit_set = frozenset(node_qargs)
if node.name in target_basis or len(node.qargs) < self._min_qubits:
if isinstance(node.op, ControlFlowOp):
if node.name in CONTROL_FLOW_OP_NAMES:
flow_blocks = []
for block in node.op.blocks:
dag_block = circuit_to_dag(block)
@ -281,7 +283,7 @@ class BasisTranslator(TransformationPass):
continue
if qubit_set in extra_instr_map:
self._replace_node(dag, node, extra_instr_map[qubit_set])
elif (node.op.name, node.op.num_qubits) in instr_map:
elif (node.name, node.num_qubits) in instr_map:
self._replace_node(dag, node, instr_map)
else:
raise TranspilerError(f"BasisTranslator did not map {node.name}.")
@ -298,20 +300,29 @@ class BasisTranslator(TransformationPass):
return dag
def _replace_node(self, dag, node, instr_map):
target_params, target_dag = instr_map[node.op.name, node.op.num_qubits]
if len(node.op.params) != len(target_params):
target_params, target_dag = instr_map[node.name, node.num_qubits]
if len(node.params) != len(target_params):
raise TranspilerError(
"Translation num_params not equal to op num_params."
f"Op: {node.op.params} {node.op.name} Translation: {target_params}\n{target_dag}"
f"Op: {node.params} {node.name} Translation: {target_params}\n{target_dag}"
)
if node.op.params:
parameter_map = dict(zip(target_params, node.op.params))
if node.params:
parameter_map = dict(zip(target_params, node.params))
bound_target_dag = target_dag.copy_empty_like()
for inner_node in target_dag.topological_op_nodes():
if any(isinstance(x, ParameterExpression) for x in inner_node.op.params):
new_op = inner_node._raw_op
if not isinstance(inner_node._raw_op, StandardGate):
new_op = inner_node.op.copy()
new_node = DAGOpNode(
new_op,
qargs=inner_node.qargs,
cargs=inner_node.cargs,
params=inner_node.params,
dag=bound_target_dag,
)
if any(isinstance(x, ParameterExpression) for x in inner_node.params):
new_params = []
for param in new_op.params:
for param in new_node.params:
if not isinstance(param, ParameterExpression):
new_params.append(param)
else:
@ -325,10 +336,10 @@ class BasisTranslator(TransformationPass):
if not new_value.parameters:
new_value = new_value.numeric()
new_params.append(new_value)
new_op.params = new_params
else:
new_op = inner_node.op
bound_target_dag.apply_operation_back(new_op, inner_node.qargs, inner_node.cargs)
new_node.params = new_params
if not isinstance(new_op, StandardGate):
new_op.params = new_params
bound_target_dag._apply_op_node_back(new_node)
if isinstance(target_dag.global_phase, ParameterExpression):
old_phase = target_dag.global_phase
bind_dict = {x: parameter_map[x] for x in old_phase.parameters}
@ -353,7 +364,7 @@ class BasisTranslator(TransformationPass):
dag_op = bound_target_dag.op_nodes()[0].op
# dag_op may be the same instance as other ops in the dag,
# so if there is a condition, need to copy
if getattr(node.op, "condition", None):
if getattr(node, "condition", None):
dag_op = dag_op.copy()
dag.substitute_node(node, dag_op, inplace=True)
@ -370,8 +381,8 @@ class BasisTranslator(TransformationPass):
def _(self, dag: DAGCircuit):
for node in dag.op_nodes():
if not dag.has_calibration_for(node) and len(node.qargs) >= self._min_qubits:
yield (node.name, node.op.num_qubits)
if isinstance(node.op, ControlFlowOp):
yield (node.name, node.num_qubits)
if node.name in CONTROL_FLOW_OP_NAMES:
for block in node.op.blocks:
yield from self._extract_basis(block)
@ -412,10 +423,10 @@ class BasisTranslator(TransformationPass):
frozenset(qargs).issuperset(incomplete_qargs)
for incomplete_qargs in self._qargs_with_non_global_operation
):
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits))
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.num_qubits))
else:
source_basis.add((node.name, node.op.num_qubits))
if isinstance(node.op, ControlFlowOp):
source_basis.add((node.name, node.num_qubits))
if node.name in CONTROL_FLOW_OP_NAMES:
for block in node.op.blocks:
block_dag = circuit_to_dag(block)
source_basis, qargs_local_source_basis = self._extract_basis_target(
@ -628,7 +639,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
doomed_nodes = [
node
for node in dag.op_nodes()
if (node.op.name, node.op.num_qubits) == (gate_name, gate_num_qubits)
if (node.name, node.num_qubits) == (gate_name, gate_num_qubits)
]
if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
@ -642,9 +653,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
for node in doomed_nodes:
replacement = equiv.assign_parameters(
dict(zip_longest(equiv_params, node.op.params))
)
replacement = equiv.assign_parameters(dict(zip_longest(equiv_params, node.params)))
replacement_dag = circuit_to_dag(replacement)
@ -666,8 +675,8 @@ def _get_example_gates(source_dag):
def recurse(dag, example_gates=None):
example_gates = example_gates or {}
for node in dag.op_nodes():
example_gates[(node.op.name, node.op.num_qubits)] = node.op
if isinstance(node.op, ControlFlowOp):
example_gates[(node.name, node.num_qubits)] = node
if node.name in CONTROL_FLOW_OP_NAMES:
for block in node.op.blocks:
example_gates = recurse(circuit_to_dag(block), example_gates)
return example_gates