diff --git a/crates/circuit/src/dag_node.rs b/crates/circuit/src/dag_node.rs index 55a40c83dc..f347ec72c8 100644 --- a/crates/circuit/src/dag_node.rs +++ b/crates/circuit/src/dag_node.rs @@ -270,6 +270,11 @@ impl DAGOpNode { self.instruction.params.to_object(py) } + #[setter] + fn set_params(&mut self, val: smallvec::SmallVec<[crate::operations::Param; 3]>) { + self.instruction.params = val; + } + pub fn is_parameterized(&self) -> bool { self.instruction.is_parameterized() } diff --git a/qiskit/transpiler/passes/basis/basis_translator.py b/qiskit/transpiler/passes/basis/basis_translator.py index f2e752dd94..e69887a3b9 100644 --- a/qiskit/transpiler/passes/basis/basis_translator.py +++ b/qiskit/transpiler/passes/basis/basis_translator.py @@ -30,11 +30,13 @@ from qiskit.circuit import ( QuantumCircuit, ParameterExpression, ) -from qiskit.dagcircuit import DAGCircuit +from qiskit.dagcircuit import DAGCircuit, DAGOpNode from qiskit.converters import circuit_to_dag, dag_to_circuit from qiskit.circuit.equivalence import Key, NodeData from qiskit.transpiler.basepasses import TransformationPass from qiskit.transpiler.exceptions import TranspilerError +from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES +from qiskit._accelerate.circuit import StandardGate logger = logging.getLogger(__name__) @@ -253,7 +255,7 @@ class BasisTranslator(TransformationPass): node_qargs = tuple(wire_map[bit] for bit in node.qargs) qubit_set = frozenset(node_qargs) if node.name in target_basis or len(node.qargs) < self._min_qubits: - if isinstance(node.op, ControlFlowOp): + if node.name in CONTROL_FLOW_OP_NAMES: flow_blocks = [] for block in node.op.blocks: dag_block = circuit_to_dag(block) @@ -281,7 +283,7 @@ class BasisTranslator(TransformationPass): continue if qubit_set in extra_instr_map: self._replace_node(dag, node, extra_instr_map[qubit_set]) - elif (node.op.name, node.op.num_qubits) in instr_map: + elif (node.name, node.num_qubits) in instr_map: self._replace_node(dag, node, instr_map) else: raise TranspilerError(f"BasisTranslator did not map {node.name}.") @@ -298,20 +300,29 @@ class BasisTranslator(TransformationPass): return dag def _replace_node(self, dag, node, instr_map): - target_params, target_dag = instr_map[node.op.name, node.op.num_qubits] - if len(node.op.params) != len(target_params): + target_params, target_dag = instr_map[node.name, node.num_qubits] + if len(node.params) != len(target_params): raise TranspilerError( "Translation num_params not equal to op num_params." - f"Op: {node.op.params} {node.op.name} Translation: {target_params}\n{target_dag}" + f"Op: {node.params} {node.name} Translation: {target_params}\n{target_dag}" ) - if node.op.params: - parameter_map = dict(zip(target_params, node.op.params)) + if node.params: + parameter_map = dict(zip(target_params, node.params)) bound_target_dag = target_dag.copy_empty_like() for inner_node in target_dag.topological_op_nodes(): - if any(isinstance(x, ParameterExpression) for x in inner_node.op.params): + new_op = inner_node._raw_op + if not isinstance(inner_node._raw_op, StandardGate): new_op = inner_node.op.copy() + new_node = DAGOpNode( + new_op, + qargs=inner_node.qargs, + cargs=inner_node.cargs, + params=inner_node.params, + dag=bound_target_dag, + ) + if any(isinstance(x, ParameterExpression) for x in inner_node.params): new_params = [] - for param in new_op.params: + for param in new_node.params: if not isinstance(param, ParameterExpression): new_params.append(param) else: @@ -325,10 +336,10 @@ class BasisTranslator(TransformationPass): if not new_value.parameters: new_value = new_value.numeric() new_params.append(new_value) - new_op.params = new_params - else: - new_op = inner_node.op - bound_target_dag.apply_operation_back(new_op, inner_node.qargs, inner_node.cargs) + new_node.params = new_params + if not isinstance(new_op, StandardGate): + new_op.params = new_params + bound_target_dag._apply_op_node_back(new_node) if isinstance(target_dag.global_phase, ParameterExpression): old_phase = target_dag.global_phase bind_dict = {x: parameter_map[x] for x in old_phase.parameters} @@ -353,7 +364,7 @@ class BasisTranslator(TransformationPass): dag_op = bound_target_dag.op_nodes()[0].op # dag_op may be the same instance as other ops in the dag, # so if there is a condition, need to copy - if getattr(node.op, "condition", None): + if getattr(node, "condition", None): dag_op = dag_op.copy() dag.substitute_node(node, dag_op, inplace=True) @@ -370,8 +381,8 @@ class BasisTranslator(TransformationPass): def _(self, dag: DAGCircuit): for node in dag.op_nodes(): if not dag.has_calibration_for(node) and len(node.qargs) >= self._min_qubits: - yield (node.name, node.op.num_qubits) - if isinstance(node.op, ControlFlowOp): + yield (node.name, node.num_qubits) + if node.name in CONTROL_FLOW_OP_NAMES: for block in node.op.blocks: yield from self._extract_basis(block) @@ -412,10 +423,10 @@ class BasisTranslator(TransformationPass): frozenset(qargs).issuperset(incomplete_qargs) for incomplete_qargs in self._qargs_with_non_global_operation ): - qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits)) + qargs_local_source_basis[frozenset(qargs)].add((node.name, node.num_qubits)) else: - source_basis.add((node.name, node.op.num_qubits)) - if isinstance(node.op, ControlFlowOp): + source_basis.add((node.name, node.num_qubits)) + if node.name in CONTROL_FLOW_OP_NAMES: for block in node.op.blocks: block_dag = circuit_to_dag(block) source_basis, qargs_local_source_basis = self._extract_basis_target( @@ -628,7 +639,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag): doomed_nodes = [ node for node in dag.op_nodes() - if (node.op.name, node.op.num_qubits) == (gate_name, gate_num_qubits) + if (node.name, node.num_qubits) == (gate_name, gate_num_qubits) ] if doomed_nodes and logger.isEnabledFor(logging.DEBUG): @@ -642,9 +653,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag): for node in doomed_nodes: - replacement = equiv.assign_parameters( - dict(zip_longest(equiv_params, node.op.params)) - ) + replacement = equiv.assign_parameters(dict(zip_longest(equiv_params, node.params))) replacement_dag = circuit_to_dag(replacement) @@ -666,8 +675,8 @@ def _get_example_gates(source_dag): def recurse(dag, example_gates=None): example_gates = example_gates or {} for node in dag.op_nodes(): - example_gates[(node.op.name, node.op.num_qubits)] = node.op - if isinstance(node.op, ControlFlowOp): + example_gates[(node.name, node.num_qubits)] = node + if node.name in CONTROL_FLOW_OP_NAMES: for block in node.op.blocks: example_gates = recurse(circuit_to_dag(block), example_gates) return example_gates