mirror of https://github.com/Qiskit/qiskit.git
Avoid Python op creation in BasisTranslator (#12705)
This commit updates the BasisTranslator transpiler pass. It builds off of #12692 and #12701 to adjust access patterns in the python transpiler path to avoid eagerly creating a Python space operation object. The goal of this PR is to mitigate the performance regression introduced by the extra conversion cost of #12459 on the BasisTranslator.
This commit is contained in:
parent
fa3d6df04f
commit
1e8205e43d
|
@ -270,6 +270,11 @@ impl DAGOpNode {
|
|||
self.instruction.params.to_object(py)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_params(&mut self, val: smallvec::SmallVec<[crate::operations::Param; 3]>) {
|
||||
self.instruction.params = val;
|
||||
}
|
||||
|
||||
pub fn is_parameterized(&self) -> bool {
|
||||
self.instruction.is_parameterized()
|
||||
}
|
||||
|
|
|
@ -30,11 +30,13 @@ from qiskit.circuit import (
|
|||
QuantumCircuit,
|
||||
ParameterExpression,
|
||||
)
|
||||
from qiskit.dagcircuit import DAGCircuit
|
||||
from qiskit.dagcircuit import DAGCircuit, DAGOpNode
|
||||
from qiskit.converters import circuit_to_dag, dag_to_circuit
|
||||
from qiskit.circuit.equivalence import Key, NodeData
|
||||
from qiskit.transpiler.basepasses import TransformationPass
|
||||
from qiskit.transpiler.exceptions import TranspilerError
|
||||
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES
|
||||
from qiskit._accelerate.circuit import StandardGate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -253,7 +255,7 @@ class BasisTranslator(TransformationPass):
|
|||
node_qargs = tuple(wire_map[bit] for bit in node.qargs)
|
||||
qubit_set = frozenset(node_qargs)
|
||||
if node.name in target_basis or len(node.qargs) < self._min_qubits:
|
||||
if isinstance(node.op, ControlFlowOp):
|
||||
if node.name in CONTROL_FLOW_OP_NAMES:
|
||||
flow_blocks = []
|
||||
for block in node.op.blocks:
|
||||
dag_block = circuit_to_dag(block)
|
||||
|
@ -281,7 +283,7 @@ class BasisTranslator(TransformationPass):
|
|||
continue
|
||||
if qubit_set in extra_instr_map:
|
||||
self._replace_node(dag, node, extra_instr_map[qubit_set])
|
||||
elif (node.op.name, node.op.num_qubits) in instr_map:
|
||||
elif (node.name, node.num_qubits) in instr_map:
|
||||
self._replace_node(dag, node, instr_map)
|
||||
else:
|
||||
raise TranspilerError(f"BasisTranslator did not map {node.name}.")
|
||||
|
@ -298,20 +300,29 @@ class BasisTranslator(TransformationPass):
|
|||
return dag
|
||||
|
||||
def _replace_node(self, dag, node, instr_map):
|
||||
target_params, target_dag = instr_map[node.op.name, node.op.num_qubits]
|
||||
if len(node.op.params) != len(target_params):
|
||||
target_params, target_dag = instr_map[node.name, node.num_qubits]
|
||||
if len(node.params) != len(target_params):
|
||||
raise TranspilerError(
|
||||
"Translation num_params not equal to op num_params."
|
||||
f"Op: {node.op.params} {node.op.name} Translation: {target_params}\n{target_dag}"
|
||||
f"Op: {node.params} {node.name} Translation: {target_params}\n{target_dag}"
|
||||
)
|
||||
if node.op.params:
|
||||
parameter_map = dict(zip(target_params, node.op.params))
|
||||
if node.params:
|
||||
parameter_map = dict(zip(target_params, node.params))
|
||||
bound_target_dag = target_dag.copy_empty_like()
|
||||
for inner_node in target_dag.topological_op_nodes():
|
||||
if any(isinstance(x, ParameterExpression) for x in inner_node.op.params):
|
||||
new_op = inner_node._raw_op
|
||||
if not isinstance(inner_node._raw_op, StandardGate):
|
||||
new_op = inner_node.op.copy()
|
||||
new_node = DAGOpNode(
|
||||
new_op,
|
||||
qargs=inner_node.qargs,
|
||||
cargs=inner_node.cargs,
|
||||
params=inner_node.params,
|
||||
dag=bound_target_dag,
|
||||
)
|
||||
if any(isinstance(x, ParameterExpression) for x in inner_node.params):
|
||||
new_params = []
|
||||
for param in new_op.params:
|
||||
for param in new_node.params:
|
||||
if not isinstance(param, ParameterExpression):
|
||||
new_params.append(param)
|
||||
else:
|
||||
|
@ -325,10 +336,10 @@ class BasisTranslator(TransformationPass):
|
|||
if not new_value.parameters:
|
||||
new_value = new_value.numeric()
|
||||
new_params.append(new_value)
|
||||
new_node.params = new_params
|
||||
if not isinstance(new_op, StandardGate):
|
||||
new_op.params = new_params
|
||||
else:
|
||||
new_op = inner_node.op
|
||||
bound_target_dag.apply_operation_back(new_op, inner_node.qargs, inner_node.cargs)
|
||||
bound_target_dag._apply_op_node_back(new_node)
|
||||
if isinstance(target_dag.global_phase, ParameterExpression):
|
||||
old_phase = target_dag.global_phase
|
||||
bind_dict = {x: parameter_map[x] for x in old_phase.parameters}
|
||||
|
@ -353,7 +364,7 @@ class BasisTranslator(TransformationPass):
|
|||
dag_op = bound_target_dag.op_nodes()[0].op
|
||||
# dag_op may be the same instance as other ops in the dag,
|
||||
# so if there is a condition, need to copy
|
||||
if getattr(node.op, "condition", None):
|
||||
if getattr(node, "condition", None):
|
||||
dag_op = dag_op.copy()
|
||||
dag.substitute_node(node, dag_op, inplace=True)
|
||||
|
||||
|
@ -370,8 +381,8 @@ class BasisTranslator(TransformationPass):
|
|||
def _(self, dag: DAGCircuit):
|
||||
for node in dag.op_nodes():
|
||||
if not dag.has_calibration_for(node) and len(node.qargs) >= self._min_qubits:
|
||||
yield (node.name, node.op.num_qubits)
|
||||
if isinstance(node.op, ControlFlowOp):
|
||||
yield (node.name, node.num_qubits)
|
||||
if node.name in CONTROL_FLOW_OP_NAMES:
|
||||
for block in node.op.blocks:
|
||||
yield from self._extract_basis(block)
|
||||
|
||||
|
@ -412,10 +423,10 @@ class BasisTranslator(TransformationPass):
|
|||
frozenset(qargs).issuperset(incomplete_qargs)
|
||||
for incomplete_qargs in self._qargs_with_non_global_operation
|
||||
):
|
||||
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits))
|
||||
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.num_qubits))
|
||||
else:
|
||||
source_basis.add((node.name, node.op.num_qubits))
|
||||
if isinstance(node.op, ControlFlowOp):
|
||||
source_basis.add((node.name, node.num_qubits))
|
||||
if node.name in CONTROL_FLOW_OP_NAMES:
|
||||
for block in node.op.blocks:
|
||||
block_dag = circuit_to_dag(block)
|
||||
source_basis, qargs_local_source_basis = self._extract_basis_target(
|
||||
|
@ -628,7 +639,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
|
|||
doomed_nodes = [
|
||||
node
|
||||
for node in dag.op_nodes()
|
||||
if (node.op.name, node.op.num_qubits) == (gate_name, gate_num_qubits)
|
||||
if (node.name, node.num_qubits) == (gate_name, gate_num_qubits)
|
||||
]
|
||||
|
||||
if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
|
||||
|
@ -642,9 +653,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
|
|||
|
||||
for node in doomed_nodes:
|
||||
|
||||
replacement = equiv.assign_parameters(
|
||||
dict(zip_longest(equiv_params, node.op.params))
|
||||
)
|
||||
replacement = equiv.assign_parameters(dict(zip_longest(equiv_params, node.params)))
|
||||
|
||||
replacement_dag = circuit_to_dag(replacement)
|
||||
|
||||
|
@ -666,8 +675,8 @@ def _get_example_gates(source_dag):
|
|||
def recurse(dag, example_gates=None):
|
||||
example_gates = example_gates or {}
|
||||
for node in dag.op_nodes():
|
||||
example_gates[(node.op.name, node.op.num_qubits)] = node.op
|
||||
if isinstance(node.op, ControlFlowOp):
|
||||
example_gates[(node.name, node.num_qubits)] = node
|
||||
if node.name in CONTROL_FLOW_OP_NAMES:
|
||||
for block in node.op.blocks:
|
||||
example_gates = recurse(circuit_to_dag(block), example_gates)
|
||||
return example_gates
|
||||
|
|
Loading…
Reference in New Issue