mirror of https://github.com/Qiskit/qiskit.git
Add DAGCircuit.substitute_node for single node replacement. (#3145)
* Add DAGCircuit.substitute_node function. * Add DAGCircuit-multi_graph consistency check. * Add inplace option to DAGCircuit.substitute_node. * Deprecate DAGNode.pop(). * Move dagcircuit validation to test file.
This commit is contained in:
parent
f9e0012aad
commit
0e39b19542
|
@ -856,6 +856,63 @@ class DAGCircuit:
|
|||
|
||||
self._multi_graph.remove_edge(p[0], self.output_map[w])
|
||||
|
||||
def substitute_node(self, node, op, inplace=False):
|
||||
"""Replace a DAGNode with a single instruction. qargs, cargs and
|
||||
conditions for the new instruction will be inferred from the node to be
|
||||
replaced. The new instruction will be checked to match the shape of the
|
||||
replaced instruction.
|
||||
|
||||
Args:
|
||||
node (DAGNode): Node to be replaced
|
||||
op (Instruction): The Instruction instance to be added to the DAG
|
||||
inplace (bool): Optional, default False. If True, existing DAG node
|
||||
will be modified to include op. Otherwise, a new DAG node will
|
||||
be used.
|
||||
|
||||
Returns:
|
||||
DAGNode: the new node containing the added instruction.
|
||||
|
||||
Raises:
|
||||
DAGCircuitError: If replacement instruction was incompatible with
|
||||
location of target node.
|
||||
"""
|
||||
|
||||
if node.type != 'op':
|
||||
raise DAGCircuitError('Only DAGNodes of type "op" can be replaced.')
|
||||
|
||||
if (
|
||||
node.op.num_qubits != op.num_qubits
|
||||
or node.op.num_clbits != op.num_clbits
|
||||
):
|
||||
raise DAGCircuitError(
|
||||
'Cannot replace node of width ({} qubits, {} clbits) with '
|
||||
'instruction of mismatched with ({} qubits, {} clbits).'.format(
|
||||
node.op.num_qubits, node.op.num_clbits,
|
||||
op.num_qubits, op.num_clbits))
|
||||
|
||||
if inplace:
|
||||
node.data_dict['op'] = op
|
||||
return node
|
||||
|
||||
self._max_node_id += 1
|
||||
new_data_dict = node.data_dict.copy()
|
||||
new_data_dict['op'] = op
|
||||
new_node = DAGNode(new_data_dict, nid=self._max_node_id)
|
||||
|
||||
self._multi_graph.add_node(new_node)
|
||||
|
||||
in_edges = self._multi_graph.in_edges(node, data=True)
|
||||
out_edges = self._multi_graph.out_edges(node, data=True)
|
||||
|
||||
self._multi_graph.add_edges_from(
|
||||
[(src, new_node, data) for src, dest, data in in_edges])
|
||||
self._multi_graph.add_edges_from(
|
||||
[(new_node, dest, data) for src, dest, data in out_edges])
|
||||
|
||||
self._multi_graph.remove_node(node)
|
||||
|
||||
return new_node
|
||||
|
||||
def node(self, node_id):
|
||||
"""Get the node in the dag.
|
||||
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
|
||||
"""Object to represent the information at a node in the DAGCircuit"""
|
||||
|
||||
import warnings
|
||||
|
||||
from qiskit.exceptions import QiskitError
|
||||
|
||||
|
||||
|
@ -109,6 +111,7 @@ class DAGNode:
|
|||
|
||||
def pop(self, val):
|
||||
"""Remove the provided value from the dictionary"""
|
||||
warnings.warn('DAGNode.pop has been deprecated.', DeprecationWarning)
|
||||
del self.data_dict[val]
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -278,13 +278,10 @@ def _transform_gate_for_layout(gate, layout):
|
|||
|
||||
mapped_op_node = deepcopy([n for n in gate['graph'].nodes() if n.type == 'op'][0])
|
||||
|
||||
# Workaround until #1816, apply mapped to qargs to both DAGNode and op
|
||||
device_qreg = QuantumRegister(len(layout.get_physical_bits()), 'q')
|
||||
mapped_qargs = [device_qreg[layout[a]] for a in mapped_op_node.qargs]
|
||||
mapped_op_node.qargs = mapped_op_node.op.qargs = mapped_qargs
|
||||
|
||||
mapped_op_node.pop('name')
|
||||
|
||||
return mapped_op_node
|
||||
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
import unittest
|
||||
|
||||
from ddt import ddt, data
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from qiskit.dagcircuit import DAGCircuit
|
||||
|
@ -29,6 +31,7 @@ from qiskit.circuit import Gate, Instruction
|
|||
from qiskit.extensions.standard.iden import IdGate
|
||||
from qiskit.extensions.standard.h import HGate
|
||||
from qiskit.extensions.standard.cx import CnotGate
|
||||
from qiskit.extensions.standard.cz import CzGate
|
||||
from qiskit.extensions.standard.x import XGate
|
||||
from qiskit.extensions.standard.u1 import U1Gate
|
||||
from qiskit.extensions.standard.barrier import Barrier
|
||||
|
@ -37,6 +40,94 @@ from qiskit.converters import circuit_to_dag
|
|||
from qiskit.test import QiskitTestCase
|
||||
|
||||
|
||||
def raise_if_dagcircuit_invalid(dag):
|
||||
"""Validates the internal consistency of a DAGCircuit._multi_graph.
|
||||
Intended for use in testing.
|
||||
|
||||
Raises:
|
||||
DAGCircuitError: if DAGCircuit._multi_graph is inconsistent.
|
||||
"""
|
||||
|
||||
multi_graph = dag._multi_graph
|
||||
|
||||
if not nx.is_directed_acyclic_graph(multi_graph):
|
||||
raise DAGCircuitError('multi_graph is not a DAG.')
|
||||
|
||||
# Every node should be of type in, out, or op.
|
||||
# All input/output nodes should be present in input_map/output_map.
|
||||
for node in multi_graph.nodes():
|
||||
if node.type == 'in':
|
||||
assert node is dag.input_map[node.wire]
|
||||
elif node.type == 'out':
|
||||
assert node is dag.output_map[node.wire]
|
||||
elif node.type == 'op':
|
||||
continue
|
||||
else:
|
||||
raise DAGCircuitError('Found node of unexpected type: {}'.format(
|
||||
node.type))
|
||||
|
||||
# Shape of node.op should match shape of node.
|
||||
for node in dag.op_nodes():
|
||||
assert len(node.qargs) == node.op.num_qubits
|
||||
assert len(node.cargs) == node.op.num_clbits
|
||||
|
||||
# Every edge should be labled with a known wire.
|
||||
edges_outside_wires = [edge_data['wire']
|
||||
for source, dest, edge_data
|
||||
in multi_graph.edges(data=True)
|
||||
if edge_data['wire'] not in dag.wires]
|
||||
if edges_outside_wires:
|
||||
raise DAGCircuitError('multi_graph contains one or more edges ({}) '
|
||||
'not found in DAGCircuit.wires ({}).'.format(
|
||||
edges_outside_wires, dag.wires))
|
||||
|
||||
# Every wire should have exactly one input node and one output node.
|
||||
for wire in dag.wires:
|
||||
in_node = dag.input_map[wire]
|
||||
out_node = dag.output_map[wire]
|
||||
|
||||
assert in_node.wire == wire
|
||||
assert out_node.wire == wire
|
||||
assert in_node.type == 'in'
|
||||
assert out_node.type == 'out'
|
||||
|
||||
# Every wire should be propagated by exactly one edge between nodes.
|
||||
for wire in dag.wires:
|
||||
cur_node = dag.input_map[wire]
|
||||
out_node = dag.output_map[wire]
|
||||
|
||||
while cur_node != out_node:
|
||||
out_edges = multi_graph.out_edges(cur_node, data=True)
|
||||
edges_to_follow = [(src, dest, data) for (src, dest, data) in out_edges
|
||||
if data['wire'] == wire]
|
||||
|
||||
assert len(edges_to_follow) == 1
|
||||
cur_node = edges_to_follow[0][1]
|
||||
|
||||
# Wires can only terminate at input/output nodes.
|
||||
for op_node in dag.op_nodes():
|
||||
assert multi_graph.in_degree(op_node) == multi_graph.out_degree(op_node)
|
||||
|
||||
# Node input/output edges should match node qarg/carg/condition.
|
||||
for node in dag.op_nodes():
|
||||
in_edges = multi_graph.in_edges(node, data=True)
|
||||
out_edges = multi_graph.out_edges(node, data=True)
|
||||
|
||||
in_wires = {data['wire'] for src, dest, data in in_edges}
|
||||
out_wires = {data['wire'] for src, dest, data in out_edges}
|
||||
|
||||
node_cond_bits = set(node.condition[0][:] if node.condition is not None else [])
|
||||
node_qubits = set(node.qargs)
|
||||
node_clbits = set(node.cargs)
|
||||
|
||||
all_bits = node_qubits | node_clbits | node_cond_bits
|
||||
|
||||
assert in_wires == all_bits, 'In-edge wires {} != node bits {}'.format(
|
||||
in_wires, all_bits)
|
||||
assert out_wires == all_bits, 'Out-edge wires {} != node bits {}'.format(
|
||||
out_wires, all_bits)
|
||||
|
||||
|
||||
class TestDagRegisters(QiskitTestCase):
|
||||
"""Test qreg and creg inside the dag"""
|
||||
|
||||
|
@ -798,6 +889,85 @@ class TestDagSubstitute(QiskitTestCase):
|
|||
self.dag.substitute_node_with_dag(instr_node, sub_dag)
|
||||
|
||||
|
||||
@ddt
|
||||
class TestDagSubstituteNode(QiskitTestCase):
|
||||
"""Test substituting a dagnode with a node."""
|
||||
|
||||
def test_substituting_node_with_wrong_width_node_raises(self):
|
||||
"""Verify replacing a node with one of a different shape raises."""
|
||||
dag = DAGCircuit()
|
||||
qr = QuantumRegister(2)
|
||||
dag.add_qreg(qr)
|
||||
node_to_be_replaced = dag.apply_operation_back(CnotGate(), [qr[0], qr[1]])
|
||||
|
||||
with self.assertRaises(DAGCircuitError) as _:
|
||||
dag.substitute_node(node_to_be_replaced, Measure())
|
||||
|
||||
@data(True, False)
|
||||
def test_substituting_io_node_raises(self, inplace):
|
||||
"""Verify replacing an io node raises."""
|
||||
dag = DAGCircuit()
|
||||
qr = QuantumRegister(1)
|
||||
dag.add_qreg(qr)
|
||||
|
||||
io_node = next(dag.nodes())
|
||||
|
||||
with self.assertRaises(DAGCircuitError) as _:
|
||||
dag.substitute_node(io_node, HGate(), inplace=inplace)
|
||||
|
||||
@data(True, False)
|
||||
def test_substituting_node_preserves_name_args_condition(self, inplace):
|
||||
"""Verify name, args and condition are preserved by a substitution."""
|
||||
dag = DAGCircuit()
|
||||
qr = QuantumRegister(2)
|
||||
cr = ClassicalRegister(1)
|
||||
dag.add_qreg(qr)
|
||||
dag.add_creg(cr)
|
||||
dag.apply_operation_back(HGate(), [qr[1]])
|
||||
node_to_be_replaced = dag.apply_operation_back(CnotGate(), [qr[1], qr[0]],
|
||||
condition=(cr, 1))
|
||||
node_to_be_replaced.name = 'test_name'
|
||||
dag.apply_operation_back(HGate(), [qr[1]])
|
||||
|
||||
replacement_node = dag.substitute_node(node_to_be_replaced, CzGate(),
|
||||
inplace=inplace)
|
||||
|
||||
raise_if_dagcircuit_invalid(dag)
|
||||
self.assertEqual(replacement_node.name, 'test_name')
|
||||
self.assertEqual(replacement_node.qargs, [qr[1], qr[0]])
|
||||
self.assertEqual(replacement_node.cargs, [])
|
||||
self.assertEqual(replacement_node.condition, (cr, 1))
|
||||
|
||||
self.assertEqual(replacement_node is node_to_be_replaced, inplace)
|
||||
|
||||
@data(True, False)
|
||||
def test_substituting_node_preserves_parents_children(self, inplace):
|
||||
"""Verify parents and children are preserved by a substitution."""
|
||||
qc = QuantumCircuit(3, 2)
|
||||
qc.cx(0, 1)
|
||||
qc.cx(1, 2)
|
||||
qc.rz(0.1, 2)
|
||||
qc.cx(1, 2)
|
||||
qc.cx(0, 1)
|
||||
dag = circuit_to_dag(qc)
|
||||
node_to_be_replaced = dag.named_nodes('rz')[0]
|
||||
predecessors = set(dag.predecessors(node_to_be_replaced))
|
||||
successors = set(dag.successors(node_to_be_replaced))
|
||||
ancestors = dag.ancestors(node_to_be_replaced)
|
||||
descendants = dag.descendants(node_to_be_replaced)
|
||||
|
||||
replacement_node = dag.substitute_node(node_to_be_replaced, U1Gate(0.1),
|
||||
inplace=inplace)
|
||||
|
||||
raise_if_dagcircuit_invalid(dag)
|
||||
self.assertEqual(set(dag.predecessors(replacement_node)), predecessors)
|
||||
self.assertEqual(set(dag.successors(replacement_node)), successors)
|
||||
self.assertEqual(dag.ancestors(replacement_node), ancestors)
|
||||
self.assertEqual(dag.descendants(replacement_node), descendants)
|
||||
|
||||
self.assertEqual(replacement_node is node_to_be_replaced, inplace)
|
||||
|
||||
|
||||
class TestDagProperties(QiskitTestCase):
|
||||
"""Test the DAG properties.
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue