Add DAGCircuit.substitute_node for single node replacement. (#3145)

* Add DAGCircuit.substitute_node function.

* Add DAGCircuit-multi_graph consistency check.

* Add inplace option to DAGCircuit.substitute_node.

* Deprecate DAGNode.pop().

* Move dagcircuit validation to test file.
This commit is contained in:
Kevin Krsulich 2019-09-27 14:56:57 -04:00 committed by Luciano
parent f9e0012aad
commit 0e39b19542
4 changed files with 230 additions and 3 deletions

View File

@ -856,6 +856,63 @@ class DAGCircuit:
self._multi_graph.remove_edge(p[0], self.output_map[w])
def substitute_node(self, node, op, inplace=False):
"""Replace a DAGNode with a single instruction. qargs, cargs and
conditions for the new instruction will be inferred from the node to be
replaced. The new instruction will be checked to match the shape of the
replaced instruction.
Args:
node (DAGNode): Node to be replaced
op (Instruction): The Instruction instance to be added to the DAG
inplace (bool): Optional, default False. If True, existing DAG node
will be modified to include op. Otherwise, a new DAG node will
be used.
Returns:
DAGNode: the new node containing the added instruction.
Raises:
DAGCircuitError: If replacement instruction was incompatible with
location of target node.
"""
if node.type != 'op':
raise DAGCircuitError('Only DAGNodes of type "op" can be replaced.')
if (
node.op.num_qubits != op.num_qubits
or node.op.num_clbits != op.num_clbits
):
raise DAGCircuitError(
'Cannot replace node of width ({} qubits, {} clbits) with '
'instruction of mismatched with ({} qubits, {} clbits).'.format(
node.op.num_qubits, node.op.num_clbits,
op.num_qubits, op.num_clbits))
if inplace:
node.data_dict['op'] = op
return node
self._max_node_id += 1
new_data_dict = node.data_dict.copy()
new_data_dict['op'] = op
new_node = DAGNode(new_data_dict, nid=self._max_node_id)
self._multi_graph.add_node(new_node)
in_edges = self._multi_graph.in_edges(node, data=True)
out_edges = self._multi_graph.out_edges(node, data=True)
self._multi_graph.add_edges_from(
[(src, new_node, data) for src, dest, data in in_edges])
self._multi_graph.add_edges_from(
[(new_node, dest, data) for src, dest, data in out_edges])
self._multi_graph.remove_node(node)
return new_node
def node(self, node_id):
"""Get the node in the dag.

View File

@ -14,6 +14,8 @@
"""Object to represent the information at a node in the DAGCircuit"""
import warnings
from qiskit.exceptions import QiskitError
@ -109,6 +111,7 @@ class DAGNode:
def pop(self, val):
"""Remove the provided value from the dictionary"""
warnings.warn('DAGNode.pop has been deprecated.', DeprecationWarning)
del self.data_dict[val]
@staticmethod

View File

@ -278,13 +278,10 @@ def _transform_gate_for_layout(gate, layout):
mapped_op_node = deepcopy([n for n in gate['graph'].nodes() if n.type == 'op'][0])
# Workaround until #1816, apply mapped to qargs to both DAGNode and op
device_qreg = QuantumRegister(len(layout.get_physical_bits()), 'q')
mapped_qargs = [device_qreg[layout[a]] for a in mapped_op_node.qargs]
mapped_op_node.qargs = mapped_op_node.op.qargs = mapped_qargs
mapped_op_node.pop('name')
return mapped_op_node

View File

@ -17,6 +17,8 @@
import unittest
from ddt import ddt, data
import networkx as nx
from qiskit.dagcircuit import DAGCircuit
@ -29,6 +31,7 @@ from qiskit.circuit import Gate, Instruction
from qiskit.extensions.standard.iden import IdGate
from qiskit.extensions.standard.h import HGate
from qiskit.extensions.standard.cx import CnotGate
from qiskit.extensions.standard.cz import CzGate
from qiskit.extensions.standard.x import XGate
from qiskit.extensions.standard.u1 import U1Gate
from qiskit.extensions.standard.barrier import Barrier
@ -37,6 +40,94 @@ from qiskit.converters import circuit_to_dag
from qiskit.test import QiskitTestCase
def raise_if_dagcircuit_invalid(dag):
"""Validates the internal consistency of a DAGCircuit._multi_graph.
Intended for use in testing.
Raises:
DAGCircuitError: if DAGCircuit._multi_graph is inconsistent.
"""
multi_graph = dag._multi_graph
if not nx.is_directed_acyclic_graph(multi_graph):
raise DAGCircuitError('multi_graph is not a DAG.')
# Every node should be of type in, out, or op.
# All input/output nodes should be present in input_map/output_map.
for node in multi_graph.nodes():
if node.type == 'in':
assert node is dag.input_map[node.wire]
elif node.type == 'out':
assert node is dag.output_map[node.wire]
elif node.type == 'op':
continue
else:
raise DAGCircuitError('Found node of unexpected type: {}'.format(
node.type))
# Shape of node.op should match shape of node.
for node in dag.op_nodes():
assert len(node.qargs) == node.op.num_qubits
assert len(node.cargs) == node.op.num_clbits
# Every edge should be labled with a known wire.
edges_outside_wires = [edge_data['wire']
for source, dest, edge_data
in multi_graph.edges(data=True)
if edge_data['wire'] not in dag.wires]
if edges_outside_wires:
raise DAGCircuitError('multi_graph contains one or more edges ({}) '
'not found in DAGCircuit.wires ({}).'.format(
edges_outside_wires, dag.wires))
# Every wire should have exactly one input node and one output node.
for wire in dag.wires:
in_node = dag.input_map[wire]
out_node = dag.output_map[wire]
assert in_node.wire == wire
assert out_node.wire == wire
assert in_node.type == 'in'
assert out_node.type == 'out'
# Every wire should be propagated by exactly one edge between nodes.
for wire in dag.wires:
cur_node = dag.input_map[wire]
out_node = dag.output_map[wire]
while cur_node != out_node:
out_edges = multi_graph.out_edges(cur_node, data=True)
edges_to_follow = [(src, dest, data) for (src, dest, data) in out_edges
if data['wire'] == wire]
assert len(edges_to_follow) == 1
cur_node = edges_to_follow[0][1]
# Wires can only terminate at input/output nodes.
for op_node in dag.op_nodes():
assert multi_graph.in_degree(op_node) == multi_graph.out_degree(op_node)
# Node input/output edges should match node qarg/carg/condition.
for node in dag.op_nodes():
in_edges = multi_graph.in_edges(node, data=True)
out_edges = multi_graph.out_edges(node, data=True)
in_wires = {data['wire'] for src, dest, data in in_edges}
out_wires = {data['wire'] for src, dest, data in out_edges}
node_cond_bits = set(node.condition[0][:] if node.condition is not None else [])
node_qubits = set(node.qargs)
node_clbits = set(node.cargs)
all_bits = node_qubits | node_clbits | node_cond_bits
assert in_wires == all_bits, 'In-edge wires {} != node bits {}'.format(
in_wires, all_bits)
assert out_wires == all_bits, 'Out-edge wires {} != node bits {}'.format(
out_wires, all_bits)
class TestDagRegisters(QiskitTestCase):
"""Test qreg and creg inside the dag"""
@ -798,6 +889,85 @@ class TestDagSubstitute(QiskitTestCase):
self.dag.substitute_node_with_dag(instr_node, sub_dag)
@ddt
class TestDagSubstituteNode(QiskitTestCase):
"""Test substituting a dagnode with a node."""
def test_substituting_node_with_wrong_width_node_raises(self):
"""Verify replacing a node with one of a different shape raises."""
dag = DAGCircuit()
qr = QuantumRegister(2)
dag.add_qreg(qr)
node_to_be_replaced = dag.apply_operation_back(CnotGate(), [qr[0], qr[1]])
with self.assertRaises(DAGCircuitError) as _:
dag.substitute_node(node_to_be_replaced, Measure())
@data(True, False)
def test_substituting_io_node_raises(self, inplace):
"""Verify replacing an io node raises."""
dag = DAGCircuit()
qr = QuantumRegister(1)
dag.add_qreg(qr)
io_node = next(dag.nodes())
with self.assertRaises(DAGCircuitError) as _:
dag.substitute_node(io_node, HGate(), inplace=inplace)
@data(True, False)
def test_substituting_node_preserves_name_args_condition(self, inplace):
"""Verify name, args and condition are preserved by a substitution."""
dag = DAGCircuit()
qr = QuantumRegister(2)
cr = ClassicalRegister(1)
dag.add_qreg(qr)
dag.add_creg(cr)
dag.apply_operation_back(HGate(), [qr[1]])
node_to_be_replaced = dag.apply_operation_back(CnotGate(), [qr[1], qr[0]],
condition=(cr, 1))
node_to_be_replaced.name = 'test_name'
dag.apply_operation_back(HGate(), [qr[1]])
replacement_node = dag.substitute_node(node_to_be_replaced, CzGate(),
inplace=inplace)
raise_if_dagcircuit_invalid(dag)
self.assertEqual(replacement_node.name, 'test_name')
self.assertEqual(replacement_node.qargs, [qr[1], qr[0]])
self.assertEqual(replacement_node.cargs, [])
self.assertEqual(replacement_node.condition, (cr, 1))
self.assertEqual(replacement_node is node_to_be_replaced, inplace)
@data(True, False)
def test_substituting_node_preserves_parents_children(self, inplace):
"""Verify parents and children are preserved by a substitution."""
qc = QuantumCircuit(3, 2)
qc.cx(0, 1)
qc.cx(1, 2)
qc.rz(0.1, 2)
qc.cx(1, 2)
qc.cx(0, 1)
dag = circuit_to_dag(qc)
node_to_be_replaced = dag.named_nodes('rz')[0]
predecessors = set(dag.predecessors(node_to_be_replaced))
successors = set(dag.successors(node_to_be_replaced))
ancestors = dag.ancestors(node_to_be_replaced)
descendants = dag.descendants(node_to_be_replaced)
replacement_node = dag.substitute_node(node_to_be_replaced, U1Gate(0.1),
inplace=inplace)
raise_if_dagcircuit_invalid(dag)
self.assertEqual(set(dag.predecessors(replacement_node)), predecessors)
self.assertEqual(set(dag.successors(replacement_node)), successors)
self.assertEqual(dag.ancestors(replacement_node), ancestors)
self.assertEqual(dag.descendants(replacement_node), descendants)
self.assertEqual(replacement_node is node_to_be_replaced, inplace)
class TestDagProperties(QiskitTestCase):
"""Test the DAG properties.
"""