mirror of https://github.com/Qiskit/qiskit.git
support control flow in BasisTranslator pass (#7808)
* create two tests and 1st modification of unroller * if_else test, parameter test * black * linting * change shallow copy of control flow ops to not copy body * add special copy * debug * add `replace_blocks` method * minor update * clean debug code * linting fix bugs * minor commit * linting * don't recurse on run * linting * don't mutate basis in "_update_basis". * Update qiskit/transpiler/passes/basis/basis_translator.py Co-authored-by: Jake Lishman <jake@binhbar.com> * apply_translation returns bool * factor out "replace_node" function * linting * singledispatchmethod -> singledispatch for python 3.7 * black * fix indentation bug * linting * black * changed _get_example_gates following @jakelishman suggestion. Co-authored-by: Jake Lishman <jake@binhbar.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
parent
4414c4eca6
commit
f28f3835fb
|
@ -97,8 +97,6 @@ class WhileLoopOp(ControlFlowOp):
|
||||||
|
|
||||||
def replace_blocks(self, blocks):
|
def replace_blocks(self, blocks):
|
||||||
(body,) = blocks
|
(body,) = blocks
|
||||||
if not isinstance(body, QuantumCircuit):
|
|
||||||
raise CircuitError("WhileLoopOp expects a single QuantumCircuit when setting blocks")
|
|
||||||
return WhileLoopOp(self.condition, body, label=self.label)
|
return WhileLoopOp(self.condition, body, label=self.label)
|
||||||
|
|
||||||
def c_if(self, classical, val):
|
def c_if(self, classical, val):
|
||||||
|
|
|
@ -362,6 +362,23 @@ class QuantumCircuit:
|
||||||
"""
|
"""
|
||||||
self._calibrations = defaultdict(dict, calibrations)
|
self._calibrations = defaultdict(dict, calibrations)
|
||||||
|
|
||||||
|
def has_calibration_for(self, instr_context: Tuple):
|
||||||
|
"""Return True if the circuit has a calibration defined for the instruction context. In this
|
||||||
|
case, the operation does not need to be translated to the device basis.
|
||||||
|
"""
|
||||||
|
instr, qargs, _ = instr_context
|
||||||
|
if not self.calibrations or instr.name not in self.calibrations:
|
||||||
|
return False
|
||||||
|
qubits = tuple(self.qubits.index(qubit) for qubit in qargs)
|
||||||
|
params = []
|
||||||
|
for p in instr.params:
|
||||||
|
if isinstance(p, ParameterExpression) and not p.parameters:
|
||||||
|
params.append(float(p))
|
||||||
|
else:
|
||||||
|
params.append(p)
|
||||||
|
params = tuple(params)
|
||||||
|
return (qubits, params) in self.calibrations[instr.name]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metadata(self) -> dict:
|
def metadata(self) -> dict:
|
||||||
"""The user provided metadata associated with the circuit
|
"""The user provided metadata associated with the circuit
|
||||||
|
|
|
@ -964,7 +964,6 @@ class DAGCircuit:
|
||||||
# Try to convert to float, but in case of unbound ParameterExpressions
|
# Try to convert to float, but in case of unbound ParameterExpressions
|
||||||
# a TypeError will be raise, fallback to normal equality in those
|
# a TypeError will be raise, fallback to normal equality in those
|
||||||
# cases
|
# cases
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self_phase = float(self.global_phase)
|
self_phase = float(self.global_phase)
|
||||||
other_phase = float(other.global_phase)
|
other_phase = float(other.global_phase)
|
||||||
|
|
|
@ -17,12 +17,14 @@ import logging
|
||||||
|
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from functools import singledispatch
|
||||||
|
|
||||||
import retworkx
|
import retworkx
|
||||||
|
|
||||||
from qiskit.circuit import Gate, ParameterVector, QuantumRegister
|
from qiskit.circuit import Gate, ParameterVector, QuantumRegister, ControlFlowOp, QuantumCircuit
|
||||||
from qiskit.circuit.equivalence import Key
|
|
||||||
from qiskit.dagcircuit import DAGCircuit
|
from qiskit.dagcircuit import DAGCircuit
|
||||||
|
from qiskit.converters import circuit_to_dag, dag_to_circuit
|
||||||
|
from qiskit.circuit.equivalence import Key
|
||||||
from qiskit.transpiler.basepasses import TransformationPass
|
from qiskit.transpiler.basepasses import TransformationPass
|
||||||
from qiskit.transpiler.exceptions import TranspilerError
|
from qiskit.transpiler.exceptions import TranspilerError
|
||||||
|
|
||||||
|
@ -119,36 +121,12 @@ class BasisTranslator(TransformationPass):
|
||||||
if self._target is None:
|
if self._target is None:
|
||||||
basic_instrs = ["measure", "reset", "barrier", "snapshot", "delay"]
|
basic_instrs = ["measure", "reset", "barrier", "snapshot", "delay"]
|
||||||
target_basis = set(self._target_basis)
|
target_basis = set(self._target_basis)
|
||||||
source_basis = set()
|
source_basis = set(_extract_basis(dag))
|
||||||
for node in dag.op_nodes():
|
|
||||||
if not dag.has_calibration_for(node):
|
|
||||||
source_basis.add((node.name, node.op.num_qubits))
|
|
||||||
qargs_local_source_basis = {}
|
qargs_local_source_basis = {}
|
||||||
else:
|
else:
|
||||||
basic_instrs = ["barrier", "snapshot"]
|
basic_instrs = ["barrier", "snapshot"]
|
||||||
source_basis = set()
|
|
||||||
target_basis = self._target.keys() - set(self._non_global_operations)
|
target_basis = self._target.keys() - set(self._non_global_operations)
|
||||||
qargs_local_source_basis = defaultdict(set)
|
source_basis, qargs_local_source_basis = self._extract_basis_target(dag, qarg_indices)
|
||||||
for node in dag.op_nodes():
|
|
||||||
qargs = tuple(qarg_indices[bit] for bit in node.qargs)
|
|
||||||
if dag.has_calibration_for(node):
|
|
||||||
continue
|
|
||||||
# Treat the instruction as on an incomplete basis if the qargs are in the
|
|
||||||
# qargs_with_non_global_operation dictionary or if any of the qubits in qargs
|
|
||||||
# are a superset for a non-local operation. For example, if the qargs
|
|
||||||
# are (0, 1) and that's a global (ie no non-local operations on (0, 1)
|
|
||||||
# operation but there is a non-local operation on (1,) we need to
|
|
||||||
# do an extra non-local search for this op to ensure we include any
|
|
||||||
# single qubit operation for (1,) as valid. This pattern also holds
|
|
||||||
# true for > 2q ops too (so for 4q operations we need to check for 3q, 2q,
|
|
||||||
# and 1q operations in the same manner)
|
|
||||||
if qargs in self._qargs_with_non_global_operation or any(
|
|
||||||
frozenset(qargs).issuperset(incomplete_qargs)
|
|
||||||
for incomplete_qargs in self._qargs_with_non_global_operation
|
|
||||||
):
|
|
||||||
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits))
|
|
||||||
else:
|
|
||||||
source_basis.add((node.name, node.op.num_qubits))
|
|
||||||
|
|
||||||
target_basis = set(target_basis).union(basic_instrs)
|
target_basis = set(target_basis).union(basic_instrs)
|
||||||
|
|
||||||
|
@ -225,68 +203,43 @@ class BasisTranslator(TransformationPass):
|
||||||
# Replace source instructions with target translations.
|
# Replace source instructions with target translations.
|
||||||
|
|
||||||
replace_start_time = time.time()
|
replace_start_time = time.time()
|
||||||
for node in dag.op_nodes():
|
|
||||||
node_qargs = tuple(qarg_indices[bit] for bit in node.qargs)
|
|
||||||
qubit_set = frozenset(node_qargs)
|
|
||||||
|
|
||||||
if node.name in target_basis:
|
def apply_translation(dag):
|
||||||
continue
|
dag_updated = False
|
||||||
if (
|
for node in dag.op_nodes():
|
||||||
node_qargs in self._qargs_with_non_global_operation
|
node_qargs = tuple(qarg_indices[bit] for bit in node.qargs)
|
||||||
and node.name in self._qargs_with_non_global_operation[node_qargs]
|
qubit_set = frozenset(node_qargs)
|
||||||
):
|
if node.name in target_basis:
|
||||||
continue
|
if isinstance(node.op, ControlFlowOp):
|
||||||
|
flow_blocks = []
|
||||||
|
for block in node.op.blocks:
|
||||||
|
dag_block = circuit_to_dag(block)
|
||||||
|
dag_updated = apply_translation(dag_block)
|
||||||
|
if dag_updated:
|
||||||
|
flow_circ_block = dag_to_circuit(dag_block)
|
||||||
|
else:
|
||||||
|
flow_circ_block = block
|
||||||
|
flow_blocks.append(flow_circ_block)
|
||||||
|
node.op = node.op.replace_blocks(flow_blocks)
|
||||||
|
continue
|
||||||
|
if (
|
||||||
|
node_qargs in self._qargs_with_non_global_operation
|
||||||
|
and node.name in self._qargs_with_non_global_operation[node_qargs]
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
if dag.has_calibration_for(node):
|
if dag.has_calibration_for(node):
|
||||||
continue
|
continue
|
||||||
|
if qubit_set in extra_instr_map:
|
||||||
def replace_node(node, instr_map):
|
self._replace_node(dag, node, extra_instr_map[qubit_set])
|
||||||
target_params, target_dag = instr_map[node.op.name, node.op.num_qubits]
|
elif (node.op.name, node.op.num_qubits) in instr_map:
|
||||||
if len(node.op.params) != len(target_params):
|
self._replace_node(dag, node, instr_map)
|
||||||
raise TranspilerError(
|
|
||||||
"Translation num_params not equal to op num_params."
|
|
||||||
"Op: {} {} Translation: {}\n{}".format(
|
|
||||||
node.op.params, node.op.name, target_params, target_dag
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if node.op.params:
|
|
||||||
# Convert target to circ and back to assign_parameters, since
|
|
||||||
# DAGCircuits won't have a ParameterTable.
|
|
||||||
from qiskit.converters import dag_to_circuit, circuit_to_dag
|
|
||||||
|
|
||||||
target_circuit = dag_to_circuit(target_dag)
|
|
||||||
|
|
||||||
target_circuit.assign_parameters(
|
|
||||||
dict(zip_longest(target_params, node.op.params)), inplace=True
|
|
||||||
)
|
|
||||||
|
|
||||||
bound_target_dag = circuit_to_dag(target_circuit)
|
|
||||||
else:
|
else:
|
||||||
bound_target_dag = target_dag
|
raise TranspilerError(f"BasisTranslator did not map {node.name}.")
|
||||||
|
dag_updated = True
|
||||||
if len(bound_target_dag.op_nodes()) == 1 and len(
|
return dag_updated
|
||||||
bound_target_dag.op_nodes()[0].qargs
|
|
||||||
) == len(node.qargs):
|
|
||||||
dag_op = bound_target_dag.op_nodes()[0].op
|
|
||||||
# dag_op may be the same instance as other ops in the dag,
|
|
||||||
# so if there is a condition, need to copy
|
|
||||||
if node.op.condition:
|
|
||||||
dag_op = dag_op.copy()
|
|
||||||
dag.substitute_node(node, dag_op, inplace=True)
|
|
||||||
|
|
||||||
if bound_target_dag.global_phase:
|
|
||||||
dag.global_phase += bound_target_dag.global_phase
|
|
||||||
else:
|
|
||||||
dag.substitute_node_with_dag(node, bound_target_dag)
|
|
||||||
|
|
||||||
if qubit_set in extra_instr_map:
|
|
||||||
replace_node(node, extra_instr_map[qubit_set])
|
|
||||||
elif (node.op.name, node.op.num_qubits) in instr_map:
|
|
||||||
replace_node(node, instr_map)
|
|
||||||
else:
|
|
||||||
raise TranspilerError(f"BasisTranslator did not map {node.name}.")
|
|
||||||
|
|
||||||
|
apply_translation(dag)
|
||||||
replace_end_time = time.time()
|
replace_end_time = time.time()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Basis translation instructions replaced in %.3fs.",
|
"Basis translation instructions replaced in %.3fs.",
|
||||||
|
@ -295,6 +248,110 @@ class BasisTranslator(TransformationPass):
|
||||||
|
|
||||||
return dag
|
return dag
|
||||||
|
|
||||||
|
def _replace_node(self, dag, node, instr_map):
|
||||||
|
target_params, target_dag = instr_map[node.op.name, node.op.num_qubits]
|
||||||
|
if len(node.op.params) != len(target_params):
|
||||||
|
raise TranspilerError(
|
||||||
|
"Translation num_params not equal to op num_params."
|
||||||
|
"Op: {} {} Translation: {}\n{}".format(
|
||||||
|
node.op.params, node.op.name, target_params, target_dag
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if node.op.params:
|
||||||
|
# Convert target to circ and back to assign_parameters, since
|
||||||
|
# DAGCircuits won't have a ParameterTable.
|
||||||
|
target_circuit = dag_to_circuit(target_dag)
|
||||||
|
|
||||||
|
target_circuit.assign_parameters(
|
||||||
|
dict(zip_longest(target_params, node.op.params)), inplace=True
|
||||||
|
)
|
||||||
|
|
||||||
|
bound_target_dag = circuit_to_dag(target_circuit)
|
||||||
|
else:
|
||||||
|
bound_target_dag = target_dag
|
||||||
|
|
||||||
|
if len(bound_target_dag.op_nodes()) == 1 and len(
|
||||||
|
bound_target_dag.op_nodes()[0].qargs
|
||||||
|
) == len(node.qargs):
|
||||||
|
dag_op = bound_target_dag.op_nodes()[0].op
|
||||||
|
# dag_op may be the same instance as other ops in the dag,
|
||||||
|
# so if there is a condition, need to copy
|
||||||
|
if node.op.condition:
|
||||||
|
dag_op = dag_op.copy()
|
||||||
|
dag.substitute_node(node, dag_op, inplace=True)
|
||||||
|
|
||||||
|
if bound_target_dag.global_phase:
|
||||||
|
dag.global_phase += bound_target_dag.global_phase
|
||||||
|
else:
|
||||||
|
dag.substitute_node_with_dag(node, bound_target_dag)
|
||||||
|
|
||||||
|
def _extract_basis_target(
|
||||||
|
self, dag, qarg_indices, source_basis=None, qargs_local_source_basis=None
|
||||||
|
):
|
||||||
|
if source_basis is None:
|
||||||
|
source_basis = set()
|
||||||
|
if qargs_local_source_basis is None:
|
||||||
|
qargs_local_source_basis = defaultdict(set)
|
||||||
|
for node in dag.op_nodes():
|
||||||
|
qargs = tuple(qarg_indices[bit] for bit in node.qargs)
|
||||||
|
if dag.has_calibration_for(node):
|
||||||
|
continue
|
||||||
|
# Treat the instruction as on an incomplete basis if the qargs are in the
|
||||||
|
# qargs_with_non_global_operation dictionary or if any of the qubits in qargs
|
||||||
|
# are a superset for a non-local operation. For example, if the qargs
|
||||||
|
# are (0, 1) and that's a global (ie no non-local operations on (0, 1)
|
||||||
|
# operation but there is a non-local operation on (1,) we need to
|
||||||
|
# do an extra non-local search for this op to ensure we include any
|
||||||
|
# single qubit operation for (1,) as valid. This pattern also holds
|
||||||
|
# true for > 2q ops too (so for 4q operations we need to check for 3q, 2q,
|
||||||
|
# and 1q operations in the same manner)
|
||||||
|
if qargs in self._qargs_with_non_global_operation or any(
|
||||||
|
frozenset(qargs).issuperset(incomplete_qargs)
|
||||||
|
for incomplete_qargs in self._qargs_with_non_global_operation
|
||||||
|
):
|
||||||
|
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits))
|
||||||
|
else:
|
||||||
|
source_basis.add((node.name, node.op.num_qubits))
|
||||||
|
if isinstance(node.op, ControlFlowOp):
|
||||||
|
for block in node.op.blocks:
|
||||||
|
block_dag = circuit_to_dag(block)
|
||||||
|
source_basis, qargs_local_source_basis = self._extract_basis_target(
|
||||||
|
block_dag,
|
||||||
|
qarg_indices,
|
||||||
|
source_basis=source_basis,
|
||||||
|
qargs_local_source_basis=qargs_local_source_basis,
|
||||||
|
)
|
||||||
|
return source_basis, qargs_local_source_basis
|
||||||
|
|
||||||
|
|
||||||
|
# this could be singledispatchmethod and included in above class when minimum
|
||||||
|
# supported python version=3.8.
|
||||||
|
@singledispatch
|
||||||
|
def _extract_basis(circuit):
|
||||||
|
return circuit
|
||||||
|
|
||||||
|
|
||||||
|
@_extract_basis.register
|
||||||
|
def _(dag: DAGCircuit):
|
||||||
|
for node in dag.op_nodes():
|
||||||
|
if not dag.has_calibration_for(node):
|
||||||
|
yield (node.name, node.op.num_qubits)
|
||||||
|
if isinstance(node.op, ControlFlowOp):
|
||||||
|
for block in node.op.blocks:
|
||||||
|
yield from _extract_basis(block)
|
||||||
|
|
||||||
|
|
||||||
|
@_extract_basis.register
|
||||||
|
def _(circ: QuantumCircuit):
|
||||||
|
for instr_context in circ.data:
|
||||||
|
instr, _, _ = instr_context
|
||||||
|
if not circ.has_calibration_for(instr_context):
|
||||||
|
yield (instr.name, instr.num_qubits)
|
||||||
|
if isinstance(instr, ControlFlowOp):
|
||||||
|
for block in instr.blocks:
|
||||||
|
yield from _extract_basis(block)
|
||||||
|
|
||||||
|
|
||||||
class StopIfBasisRewritable(Exception):
|
class StopIfBasisRewritable(Exception):
|
||||||
"""Custom exception that signals `retworkx.dijkstra_search` to stop."""
|
"""Custom exception that signals `retworkx.dijkstra_search` to stop."""
|
||||||
|
@ -486,8 +543,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
|
||||||
source_basis but not affected by basis_transforms will be included
|
source_basis but not affected by basis_transforms will be included
|
||||||
as a key mapping to itself.
|
as a key mapping to itself.
|
||||||
"""
|
"""
|
||||||
|
example_gates = _get_example_gates(source_dag)
|
||||||
example_gates = {(node.op.name, node.op.num_qubits): node.op for node in source_dag.op_nodes()}
|
|
||||||
mapped_instrs = {}
|
mapped_instrs = {}
|
||||||
|
|
||||||
for gate_name, gate_num_qubits in source_basis:
|
for gate_name, gate_num_qubits in source_basis:
|
||||||
|
@ -523,7 +579,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
|
||||||
]
|
]
|
||||||
|
|
||||||
if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
|
if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
|
||||||
from qiskit.converters import dag_to_circuit
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Updating transform for mapped instr %s %s from \n%s",
|
"Updating transform for mapped instr %s %s from \n%s",
|
||||||
|
@ -533,7 +588,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
|
||||||
)
|
)
|
||||||
|
|
||||||
for node in doomed_nodes:
|
for node in doomed_nodes:
|
||||||
from qiskit.converters import circuit_to_dag
|
|
||||||
|
|
||||||
replacement = equiv.assign_parameters(
|
replacement = equiv.assign_parameters(
|
||||||
dict(zip_longest(equiv_params, node.op.params))
|
dict(zip_longest(equiv_params, node.op.params))
|
||||||
|
@ -544,7 +598,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
|
||||||
dag.substitute_node_with_dag(node, replacement_dag)
|
dag.substitute_node_with_dag(node, replacement_dag)
|
||||||
|
|
||||||
if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
|
if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
|
||||||
from qiskit.converters import dag_to_circuit
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Updated transform for mapped instr %s %s to\n%s",
|
"Updated transform for mapped instr %s %s to\n%s",
|
||||||
|
@ -554,3 +607,16 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
|
||||||
)
|
)
|
||||||
|
|
||||||
return mapped_instrs
|
return mapped_instrs
|
||||||
|
|
||||||
|
|
||||||
|
def _get_example_gates(source_dag):
|
||||||
|
def recurse(dag, example_gates=None):
|
||||||
|
example_gates = example_gates or {}
|
||||||
|
for node in dag.op_nodes():
|
||||||
|
example_gates[(node.op.name, node.op.num_qubits)] = node.op
|
||||||
|
if isinstance(node.op, ControlFlowOp):
|
||||||
|
for block in node.op.blocks:
|
||||||
|
example_gates = recurse(circuit_to_dag(block), example_gates)
|
||||||
|
return example_gates
|
||||||
|
|
||||||
|
return recurse(source_dag)
|
||||||
|
|
|
@ -20,7 +20,7 @@ from numpy import pi
|
||||||
from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit
|
from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit
|
||||||
from qiskit import transpile
|
from qiskit import transpile
|
||||||
from qiskit.test import QiskitTestCase
|
from qiskit.test import QiskitTestCase
|
||||||
from qiskit.circuit import Gate, Parameter, EquivalenceLibrary
|
from qiskit.circuit import Gate, Parameter, EquivalenceLibrary, Qubit, Clbit
|
||||||
from qiskit.circuit.library import (
|
from qiskit.circuit.library import (
|
||||||
U1Gate,
|
U1Gate,
|
||||||
U2Gate,
|
U2Gate,
|
||||||
|
@ -49,15 +49,15 @@ from qiskit.circuit.library.standard_gates.equivalence_library import (
|
||||||
class OneQubitZeroParamGate(Gate):
|
class OneQubitZeroParamGate(Gate):
|
||||||
"""Mock one qubit zero param gate."""
|
"""Mock one qubit zero param gate."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, name="1q0p"):
|
||||||
super().__init__("1q0p", 1, [])
|
super().__init__(name, 1, [])
|
||||||
|
|
||||||
|
|
||||||
class OneQubitOneParamGate(Gate):
|
class OneQubitOneParamGate(Gate):
|
||||||
"""Mock one qubit one param gate."""
|
"""Mock one qubit one param gate."""
|
||||||
|
|
||||||
def __init__(self, theta):
|
def __init__(self, theta, name="1q1p"):
|
||||||
super().__init__("1q1p", 1, [theta])
|
super().__init__(name, 1, [theta])
|
||||||
|
|
||||||
|
|
||||||
class OneQubitOneParamPrimeGate(Gate):
|
class OneQubitOneParamPrimeGate(Gate):
|
||||||
|
@ -70,22 +70,22 @@ class OneQubitOneParamPrimeGate(Gate):
|
||||||
class OneQubitTwoParamGate(Gate):
|
class OneQubitTwoParamGate(Gate):
|
||||||
"""Mock one qubit two param gate."""
|
"""Mock one qubit two param gate."""
|
||||||
|
|
||||||
def __init__(self, phi, lam):
|
def __init__(self, phi, lam, name="1q2p"):
|
||||||
super().__init__("1q2p", 1, [phi, lam])
|
super().__init__(name, 1, [phi, lam])
|
||||||
|
|
||||||
|
|
||||||
class TwoQubitZeroParamGate(Gate):
|
class TwoQubitZeroParamGate(Gate):
|
||||||
"""Mock one qubit zero param gate."""
|
"""Mock one qubit zero param gate."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, name="2q0p"):
|
||||||
super().__init__("2q0p", 2, [])
|
super().__init__(name, 2, [])
|
||||||
|
|
||||||
|
|
||||||
class VariadicZeroParamGate(Gate):
|
class VariadicZeroParamGate(Gate):
|
||||||
"""Mock variadic zero param gate."""
|
"""Mock variadic zero param gate."""
|
||||||
|
|
||||||
def __init__(self, num_qubits):
|
def __init__(self, num_qubits, name="vq0p"):
|
||||||
super().__init__("vq0p", num_qubits, [])
|
super().__init__(name, num_qubits, [])
|
||||||
|
|
||||||
|
|
||||||
class TestBasisTranslator(QiskitTestCase):
|
class TestBasisTranslator(QiskitTestCase):
|
||||||
|
@ -382,6 +382,77 @@ class TestBasisTranslator(QiskitTestCase):
|
||||||
|
|
||||||
self.assertEqual(actual, expected_dag)
|
self.assertEqual(actual, expected_dag)
|
||||||
|
|
||||||
|
def test_if_else(self):
|
||||||
|
"""Test a simple if-else with parameters."""
|
||||||
|
qubits = [Qubit(), Qubit()]
|
||||||
|
clbits = [Clbit(), Clbit()]
|
||||||
|
alpha = Parameter("alpha")
|
||||||
|
beta = Parameter("beta")
|
||||||
|
gate = OneQubitOneParamGate(alpha)
|
||||||
|
equiv = QuantumCircuit([qubits[0]])
|
||||||
|
equiv.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]])
|
||||||
|
equiv.append(OneQubitOneParamGate(alpha, name="1q1p_2"), [qubits[0]])
|
||||||
|
|
||||||
|
eq_lib = EquivalenceLibrary()
|
||||||
|
eq_lib.add_equivalence(gate, equiv)
|
||||||
|
|
||||||
|
circ = QuantumCircuit(qubits, clbits)
|
||||||
|
circ.append(OneQubitOneParamGate(beta), [qubits[0]])
|
||||||
|
circ.measure(qubits[0], clbits[1])
|
||||||
|
with circ.if_test((clbits[1], 0)) as else_:
|
||||||
|
circ.append(OneQubitOneParamGate(alpha), [qubits[0]])
|
||||||
|
circ.append(TwoQubitZeroParamGate(), qubits)
|
||||||
|
with else_:
|
||||||
|
circ.append(TwoQubitZeroParamGate(), [qubits[1], qubits[0]])
|
||||||
|
dag = circuit_to_dag(circ)
|
||||||
|
dag_translated = BasisTranslator(eq_lib, ["if_else", "1q0p_2", "1q1p_2", "2q0p"]).run(dag)
|
||||||
|
|
||||||
|
expected = QuantumCircuit(qubits, clbits)
|
||||||
|
expected.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]])
|
||||||
|
expected.append(OneQubitOneParamGate(beta, name="1q1p_2"), [qubits[0]])
|
||||||
|
expected.measure(qubits[0], clbits[1])
|
||||||
|
with expected.if_test((clbits[1], 0)) as else_:
|
||||||
|
expected.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]])
|
||||||
|
expected.append(OneQubitOneParamGate(alpha, name="1q1p_2"), [qubits[0]])
|
||||||
|
expected.append(TwoQubitZeroParamGate(), qubits)
|
||||||
|
with else_:
|
||||||
|
expected.append(TwoQubitZeroParamGate(), [qubits[1], qubits[0]])
|
||||||
|
dag_expected = circuit_to_dag(expected)
|
||||||
|
self.assertEqual(dag_translated, dag_expected)
|
||||||
|
|
||||||
|
def test_nested_loop(self):
|
||||||
|
"""Test a simple if-else with parameters."""
|
||||||
|
qubits = [Qubit(), Qubit()]
|
||||||
|
clbits = [Clbit(), Clbit()]
|
||||||
|
cr = ClassicalRegister(bits=clbits)
|
||||||
|
index1 = Parameter("index1")
|
||||||
|
alpha = Parameter("alpha")
|
||||||
|
|
||||||
|
gate = OneQubitOneParamGate(alpha)
|
||||||
|
equiv = QuantumCircuit([qubits[0]])
|
||||||
|
equiv.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]])
|
||||||
|
equiv.append(OneQubitOneParamGate(alpha, name="1q1p_2"), [qubits[0]])
|
||||||
|
|
||||||
|
eq_lib = EquivalenceLibrary()
|
||||||
|
eq_lib.add_equivalence(gate, equiv)
|
||||||
|
|
||||||
|
circ = QuantumCircuit(qubits, cr)
|
||||||
|
with circ.for_loop(range(3), loop_parameter=index1) as ind:
|
||||||
|
with circ.while_loop((cr, 0)):
|
||||||
|
circ.append(OneQubitOneParamGate(alpha * ind), [qubits[0]])
|
||||||
|
dag = circuit_to_dag(circ)
|
||||||
|
dag_translated = BasisTranslator(
|
||||||
|
eq_lib, ["if_else", "for_loop", "while_loop", "1q0p_2", "1q1p_2"]
|
||||||
|
).run(dag)
|
||||||
|
|
||||||
|
expected = QuantumCircuit(qubits, cr)
|
||||||
|
with expected.for_loop(range(3), loop_parameter=index1) as ind:
|
||||||
|
with expected.while_loop((cr, 0)):
|
||||||
|
expected.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]])
|
||||||
|
expected.append(OneQubitOneParamGate(alpha * ind, name="1q1p_2"), [qubits[0]])
|
||||||
|
dag_expected = circuit_to_dag(expected)
|
||||||
|
self.assertEqual(dag_translated, dag_expected)
|
||||||
|
|
||||||
|
|
||||||
class TestUnrollerCompatability(QiskitTestCase):
|
class TestUnrollerCompatability(QiskitTestCase):
|
||||||
"""Tests backward compatability with the Unroller pass.
|
"""Tests backward compatability with the Unroller pass.
|
||||||
|
|
Loading…
Reference in New Issue