support control flow in BasisTranslator pass (#7808)

* create two tests and 1st modification of unroller

* if_else test, parameter test

* black

* linting

* change shallow copy of control flow ops to not copy body

* add special copy

* debug

* add `replace_blocks` method

* minor update

* clean debug code

* linting fix bugs

* minor commit

* linting

* don't recurse on run

* linting

* don't mutate basis in "_update_basis".

* Update qiskit/transpiler/passes/basis/basis_translator.py

Co-authored-by: Jake Lishman <jake@binhbar.com>

* apply_translation returns bool

* factor out "replace_node" function

* linting

* singledispatchmethod -> singledispatch for python 3.7

* black

* fix indentation bug

* linting

* black

* changed _get_example_gates following @jakelishman suggestion.

Co-authored-by: Jake Lishman <jake@binhbar.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
ewinston 2022-06-21 13:40:00 -04:00 committed by GitHub
parent 4414c4eca6
commit f28f3835fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 256 additions and 105 deletions

View File

@ -97,8 +97,6 @@ class WhileLoopOp(ControlFlowOp):
def replace_blocks(self, blocks):
(body,) = blocks
if not isinstance(body, QuantumCircuit):
raise CircuitError("WhileLoopOp expects a single QuantumCircuit when setting blocks")
return WhileLoopOp(self.condition, body, label=self.label)
def c_if(self, classical, val):

View File

@ -362,6 +362,23 @@ class QuantumCircuit:
"""
self._calibrations = defaultdict(dict, calibrations)
def has_calibration_for(self, instr_context: Tuple):
"""Return True if the circuit has a calibration defined for the instruction context. In this
case, the operation does not need to be translated to the device basis.
"""
instr, qargs, _ = instr_context
if not self.calibrations or instr.name not in self.calibrations:
return False
qubits = tuple(self.qubits.index(qubit) for qubit in qargs)
params = []
for p in instr.params:
if isinstance(p, ParameterExpression) and not p.parameters:
params.append(float(p))
else:
params.append(p)
params = tuple(params)
return (qubits, params) in self.calibrations[instr.name]
@property
def metadata(self) -> dict:
"""The user provided metadata associated with the circuit

View File

@ -964,7 +964,6 @@ class DAGCircuit:
# Try to convert to float, but in case of unbound ParameterExpressions
# a TypeError will be raise, fallback to normal equality in those
# cases
try:
self_phase = float(self.global_phase)
other_phase = float(other.global_phase)

View File

@ -17,12 +17,14 @@ import logging
from itertools import zip_longest
from collections import defaultdict
from functools import singledispatch
import retworkx
from qiskit.circuit import Gate, ParameterVector, QuantumRegister
from qiskit.circuit.equivalence import Key
from qiskit.circuit import Gate, ParameterVector, QuantumRegister, ControlFlowOp, QuantumCircuit
from qiskit.dagcircuit import DAGCircuit
from qiskit.converters import circuit_to_dag, dag_to_circuit
from qiskit.circuit.equivalence import Key
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.exceptions import TranspilerError
@ -119,36 +121,12 @@ class BasisTranslator(TransformationPass):
if self._target is None:
basic_instrs = ["measure", "reset", "barrier", "snapshot", "delay"]
target_basis = set(self._target_basis)
source_basis = set()
for node in dag.op_nodes():
if not dag.has_calibration_for(node):
source_basis.add((node.name, node.op.num_qubits))
source_basis = set(_extract_basis(dag))
qargs_local_source_basis = {}
else:
basic_instrs = ["barrier", "snapshot"]
source_basis = set()
target_basis = self._target.keys() - set(self._non_global_operations)
qargs_local_source_basis = defaultdict(set)
for node in dag.op_nodes():
qargs = tuple(qarg_indices[bit] for bit in node.qargs)
if dag.has_calibration_for(node):
continue
# Treat the instruction as on an incomplete basis if the qargs are in the
# qargs_with_non_global_operation dictionary or if any of the qubits in qargs
# are a superset for a non-local operation. For example, if the qargs
# are (0, 1) and that's a global (ie no non-local operations on (0, 1)
# operation but there is a non-local operation on (1,) we need to
# do an extra non-local search for this op to ensure we include any
# single qubit operation for (1,) as valid. This pattern also holds
# true for > 2q ops too (so for 4q operations we need to check for 3q, 2q,
# and 1q operations in the same manner)
if qargs in self._qargs_with_non_global_operation or any(
frozenset(qargs).issuperset(incomplete_qargs)
for incomplete_qargs in self._qargs_with_non_global_operation
):
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits))
else:
source_basis.add((node.name, node.op.num_qubits))
source_basis, qargs_local_source_basis = self._extract_basis_target(dag, qarg_indices)
target_basis = set(target_basis).union(basic_instrs)
@ -225,68 +203,43 @@ class BasisTranslator(TransformationPass):
# Replace source instructions with target translations.
replace_start_time = time.time()
for node in dag.op_nodes():
node_qargs = tuple(qarg_indices[bit] for bit in node.qargs)
qubit_set = frozenset(node_qargs)
if node.name in target_basis:
continue
if (
node_qargs in self._qargs_with_non_global_operation
and node.name in self._qargs_with_non_global_operation[node_qargs]
):
continue
def apply_translation(dag):
dag_updated = False
for node in dag.op_nodes():
node_qargs = tuple(qarg_indices[bit] for bit in node.qargs)
qubit_set = frozenset(node_qargs)
if node.name in target_basis:
if isinstance(node.op, ControlFlowOp):
flow_blocks = []
for block in node.op.blocks:
dag_block = circuit_to_dag(block)
dag_updated = apply_translation(dag_block)
if dag_updated:
flow_circ_block = dag_to_circuit(dag_block)
else:
flow_circ_block = block
flow_blocks.append(flow_circ_block)
node.op = node.op.replace_blocks(flow_blocks)
continue
if (
node_qargs in self._qargs_with_non_global_operation
and node.name in self._qargs_with_non_global_operation[node_qargs]
):
continue
if dag.has_calibration_for(node):
continue
def replace_node(node, instr_map):
target_params, target_dag = instr_map[node.op.name, node.op.num_qubits]
if len(node.op.params) != len(target_params):
raise TranspilerError(
"Translation num_params not equal to op num_params."
"Op: {} {} Translation: {}\n{}".format(
node.op.params, node.op.name, target_params, target_dag
)
)
if node.op.params:
# Convert target to circ and back to assign_parameters, since
# DAGCircuits won't have a ParameterTable.
from qiskit.converters import dag_to_circuit, circuit_to_dag
target_circuit = dag_to_circuit(target_dag)
target_circuit.assign_parameters(
dict(zip_longest(target_params, node.op.params)), inplace=True
)
bound_target_dag = circuit_to_dag(target_circuit)
if dag.has_calibration_for(node):
continue
if qubit_set in extra_instr_map:
self._replace_node(dag, node, extra_instr_map[qubit_set])
elif (node.op.name, node.op.num_qubits) in instr_map:
self._replace_node(dag, node, instr_map)
else:
bound_target_dag = target_dag
if len(bound_target_dag.op_nodes()) == 1 and len(
bound_target_dag.op_nodes()[0].qargs
) == len(node.qargs):
dag_op = bound_target_dag.op_nodes()[0].op
# dag_op may be the same instance as other ops in the dag,
# so if there is a condition, need to copy
if node.op.condition:
dag_op = dag_op.copy()
dag.substitute_node(node, dag_op, inplace=True)
if bound_target_dag.global_phase:
dag.global_phase += bound_target_dag.global_phase
else:
dag.substitute_node_with_dag(node, bound_target_dag)
if qubit_set in extra_instr_map:
replace_node(node, extra_instr_map[qubit_set])
elif (node.op.name, node.op.num_qubits) in instr_map:
replace_node(node, instr_map)
else:
raise TranspilerError(f"BasisTranslator did not map {node.name}.")
raise TranspilerError(f"BasisTranslator did not map {node.name}.")
dag_updated = True
return dag_updated
apply_translation(dag)
replace_end_time = time.time()
logger.info(
"Basis translation instructions replaced in %.3fs.",
@ -295,6 +248,110 @@ class BasisTranslator(TransformationPass):
return dag
def _replace_node(self, dag, node, instr_map):
target_params, target_dag = instr_map[node.op.name, node.op.num_qubits]
if len(node.op.params) != len(target_params):
raise TranspilerError(
"Translation num_params not equal to op num_params."
"Op: {} {} Translation: {}\n{}".format(
node.op.params, node.op.name, target_params, target_dag
)
)
if node.op.params:
# Convert target to circ and back to assign_parameters, since
# DAGCircuits won't have a ParameterTable.
target_circuit = dag_to_circuit(target_dag)
target_circuit.assign_parameters(
dict(zip_longest(target_params, node.op.params)), inplace=True
)
bound_target_dag = circuit_to_dag(target_circuit)
else:
bound_target_dag = target_dag
if len(bound_target_dag.op_nodes()) == 1 and len(
bound_target_dag.op_nodes()[0].qargs
) == len(node.qargs):
dag_op = bound_target_dag.op_nodes()[0].op
# dag_op may be the same instance as other ops in the dag,
# so if there is a condition, need to copy
if node.op.condition:
dag_op = dag_op.copy()
dag.substitute_node(node, dag_op, inplace=True)
if bound_target_dag.global_phase:
dag.global_phase += bound_target_dag.global_phase
else:
dag.substitute_node_with_dag(node, bound_target_dag)
def _extract_basis_target(
self, dag, qarg_indices, source_basis=None, qargs_local_source_basis=None
):
if source_basis is None:
source_basis = set()
if qargs_local_source_basis is None:
qargs_local_source_basis = defaultdict(set)
for node in dag.op_nodes():
qargs = tuple(qarg_indices[bit] for bit in node.qargs)
if dag.has_calibration_for(node):
continue
# Treat the instruction as on an incomplete basis if the qargs are in the
# qargs_with_non_global_operation dictionary or if any of the qubits in qargs
# are a superset for a non-local operation. For example, if the qargs
# are (0, 1) and that's a global (ie no non-local operations on (0, 1)
# operation but there is a non-local operation on (1,) we need to
# do an extra non-local search for this op to ensure we include any
# single qubit operation for (1,) as valid. This pattern also holds
# true for > 2q ops too (so for 4q operations we need to check for 3q, 2q,
# and 1q operations in the same manner)
if qargs in self._qargs_with_non_global_operation or any(
frozenset(qargs).issuperset(incomplete_qargs)
for incomplete_qargs in self._qargs_with_non_global_operation
):
qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits))
else:
source_basis.add((node.name, node.op.num_qubits))
if isinstance(node.op, ControlFlowOp):
for block in node.op.blocks:
block_dag = circuit_to_dag(block)
source_basis, qargs_local_source_basis = self._extract_basis_target(
block_dag,
qarg_indices,
source_basis=source_basis,
qargs_local_source_basis=qargs_local_source_basis,
)
return source_basis, qargs_local_source_basis
# this could be singledispatchmethod and included in above class when minimum
# supported python version=3.8.
@singledispatch
def _extract_basis(circuit):
return circuit
@_extract_basis.register
def _(dag: DAGCircuit):
for node in dag.op_nodes():
if not dag.has_calibration_for(node):
yield (node.name, node.op.num_qubits)
if isinstance(node.op, ControlFlowOp):
for block in node.op.blocks:
yield from _extract_basis(block)
@_extract_basis.register
def _(circ: QuantumCircuit):
for instr_context in circ.data:
instr, _, _ = instr_context
if not circ.has_calibration_for(instr_context):
yield (instr.name, instr.num_qubits)
if isinstance(instr, ControlFlowOp):
for block in instr.blocks:
yield from _extract_basis(block)
class StopIfBasisRewritable(Exception):
"""Custom exception that signals `retworkx.dijkstra_search` to stop."""
@ -486,8 +543,7 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
source_basis but not affected by basis_transforms will be included
as a key mapping to itself.
"""
example_gates = {(node.op.name, node.op.num_qubits): node.op for node in source_dag.op_nodes()}
example_gates = _get_example_gates(source_dag)
mapped_instrs = {}
for gate_name, gate_num_qubits in source_basis:
@ -523,7 +579,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
]
if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
from qiskit.converters import dag_to_circuit
logger.debug(
"Updating transform for mapped instr %s %s from \n%s",
@ -533,7 +588,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
)
for node in doomed_nodes:
from qiskit.converters import circuit_to_dag
replacement = equiv.assign_parameters(
dict(zip_longest(equiv_params, node.op.params))
@ -544,7 +598,6 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
dag.substitute_node_with_dag(node, replacement_dag)
if doomed_nodes and logger.isEnabledFor(logging.DEBUG):
from qiskit.converters import dag_to_circuit
logger.debug(
"Updated transform for mapped instr %s %s to\n%s",
@ -554,3 +607,16 @@ def _compose_transforms(basis_transforms, source_basis, source_dag):
)
return mapped_instrs
def _get_example_gates(source_dag):
def recurse(dag, example_gates=None):
example_gates = example_gates or {}
for node in dag.op_nodes():
example_gates[(node.op.name, node.op.num_qubits)] = node.op
if isinstance(node.op, ControlFlowOp):
for block in node.op.blocks:
example_gates = recurse(circuit_to_dag(block), example_gates)
return example_gates
return recurse(source_dag)

View File

@ -20,7 +20,7 @@ from numpy import pi
from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit
from qiskit import transpile
from qiskit.test import QiskitTestCase
from qiskit.circuit import Gate, Parameter, EquivalenceLibrary
from qiskit.circuit import Gate, Parameter, EquivalenceLibrary, Qubit, Clbit
from qiskit.circuit.library import (
U1Gate,
U2Gate,
@ -49,15 +49,15 @@ from qiskit.circuit.library.standard_gates.equivalence_library import (
class OneQubitZeroParamGate(Gate):
"""Mock one qubit zero param gate."""
def __init__(self):
super().__init__("1q0p", 1, [])
def __init__(self, name="1q0p"):
super().__init__(name, 1, [])
class OneQubitOneParamGate(Gate):
"""Mock one qubit one param gate."""
def __init__(self, theta):
super().__init__("1q1p", 1, [theta])
def __init__(self, theta, name="1q1p"):
super().__init__(name, 1, [theta])
class OneQubitOneParamPrimeGate(Gate):
@ -70,22 +70,22 @@ class OneQubitOneParamPrimeGate(Gate):
class OneQubitTwoParamGate(Gate):
"""Mock one qubit two param gate."""
def __init__(self, phi, lam):
super().__init__("1q2p", 1, [phi, lam])
def __init__(self, phi, lam, name="1q2p"):
super().__init__(name, 1, [phi, lam])
class TwoQubitZeroParamGate(Gate):
"""Mock one qubit zero param gate."""
def __init__(self):
super().__init__("2q0p", 2, [])
def __init__(self, name="2q0p"):
super().__init__(name, 2, [])
class VariadicZeroParamGate(Gate):
"""Mock variadic zero param gate."""
def __init__(self, num_qubits):
super().__init__("vq0p", num_qubits, [])
def __init__(self, num_qubits, name="vq0p"):
super().__init__(name, num_qubits, [])
class TestBasisTranslator(QiskitTestCase):
@ -382,6 +382,77 @@ class TestBasisTranslator(QiskitTestCase):
self.assertEqual(actual, expected_dag)
def test_if_else(self):
"""Test a simple if-else with parameters."""
qubits = [Qubit(), Qubit()]
clbits = [Clbit(), Clbit()]
alpha = Parameter("alpha")
beta = Parameter("beta")
gate = OneQubitOneParamGate(alpha)
equiv = QuantumCircuit([qubits[0]])
equiv.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]])
equiv.append(OneQubitOneParamGate(alpha, name="1q1p_2"), [qubits[0]])
eq_lib = EquivalenceLibrary()
eq_lib.add_equivalence(gate, equiv)
circ = QuantumCircuit(qubits, clbits)
circ.append(OneQubitOneParamGate(beta), [qubits[0]])
circ.measure(qubits[0], clbits[1])
with circ.if_test((clbits[1], 0)) as else_:
circ.append(OneQubitOneParamGate(alpha), [qubits[0]])
circ.append(TwoQubitZeroParamGate(), qubits)
with else_:
circ.append(TwoQubitZeroParamGate(), [qubits[1], qubits[0]])
dag = circuit_to_dag(circ)
dag_translated = BasisTranslator(eq_lib, ["if_else", "1q0p_2", "1q1p_2", "2q0p"]).run(dag)
expected = QuantumCircuit(qubits, clbits)
expected.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]])
expected.append(OneQubitOneParamGate(beta, name="1q1p_2"), [qubits[0]])
expected.measure(qubits[0], clbits[1])
with expected.if_test((clbits[1], 0)) as else_:
expected.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]])
expected.append(OneQubitOneParamGate(alpha, name="1q1p_2"), [qubits[0]])
expected.append(TwoQubitZeroParamGate(), qubits)
with else_:
expected.append(TwoQubitZeroParamGate(), [qubits[1], qubits[0]])
dag_expected = circuit_to_dag(expected)
self.assertEqual(dag_translated, dag_expected)
def test_nested_loop(self):
"""Test a simple if-else with parameters."""
qubits = [Qubit(), Qubit()]
clbits = [Clbit(), Clbit()]
cr = ClassicalRegister(bits=clbits)
index1 = Parameter("index1")
alpha = Parameter("alpha")
gate = OneQubitOneParamGate(alpha)
equiv = QuantumCircuit([qubits[0]])
equiv.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]])
equiv.append(OneQubitOneParamGate(alpha, name="1q1p_2"), [qubits[0]])
eq_lib = EquivalenceLibrary()
eq_lib.add_equivalence(gate, equiv)
circ = QuantumCircuit(qubits, cr)
with circ.for_loop(range(3), loop_parameter=index1) as ind:
with circ.while_loop((cr, 0)):
circ.append(OneQubitOneParamGate(alpha * ind), [qubits[0]])
dag = circuit_to_dag(circ)
dag_translated = BasisTranslator(
eq_lib, ["if_else", "for_loop", "while_loop", "1q0p_2", "1q1p_2"]
).run(dag)
expected = QuantumCircuit(qubits, cr)
with expected.for_loop(range(3), loop_parameter=index1) as ind:
with expected.while_loop((cr, 0)):
expected.append(OneQubitZeroParamGate(name="1q0p_2"), [qubits[0]])
expected.append(OneQubitOneParamGate(alpha * ind, name="1q1p_2"), [qubits[0]])
dag_expected = circuit_to_dag(expected)
self.assertEqual(dag_translated, dag_expected)
class TestUnrollerCompatability(QiskitTestCase):
"""Tests backward compatability with the Unroller pass.