Add ax kwarg to matplotlib visualization functions (#3053)

* Add ax kwarg to matplotlib visualization functions

This commit adds support for passing a matplotlib axes to visualization
functions. This enables integrating the qiskit visualizations into a
larger matplotlib visualization. When an ax kwarg is set that input will
be used for the output visualization and a new figure will not be
generated or returned.

Fixes #1950

* typo in release note

* Add ax kwarg to matplotlib visualization functions

This commit adds support for passing a matplotlib axes to visualization
functions. This enables integrating the qiskit visualizations into a
larger matplotlib visualization. When an ax kwarg is set that input will
be used for the output visualization and a new figure will not be
generated or returned.

Fixes #1950

* Fix lint and typos

* pylint: disable=invalid-name

* Fix lost lines from rebase

As part of rebasing and cleaning up the branch from several failed merge
attempts a few lines from plot_state_city() were lost. This commit fixes
that mistake and restores the z axis labels which were missing from the
imaginary component subplot.

* Fix rebase typo

* Revert "pylint: disable=invalid-name"

This reverts commit 5566611719.
This commit is contained in:
Matthew Treinish 2019-10-08 08:22:51 -04:00 committed by Luciano
parent 1cb65623b4
commit 352f940dfa
7 changed files with 311 additions and 156 deletions

View File

@ -172,7 +172,7 @@ attr-rgx=[a-z_][a-z0-9_]{2,30}$
attr-name-hint=[a-z_][a-z0-9_]{2,30}$
# Regular expression matching correct argument names
argument-rgx=[a-z_][a-z0-9_]{2,30}$
argument-rgx=[a-z_][a-z0-9_]{2,30}|ax$
# Naming hint for argument names
argument-name-hint=[a-z_][a-z0-9_]{2,30}$

View File

@ -560,8 +560,8 @@ class QuantumCircuit:
def draw(self, scale=0.7, filename=None, style=None, output=None,
interactive=False, line_length=None, plot_barriers=True,
reverse_bits=False, justify=None, idle_wires=True, vertical_compression='medium',
with_layout=True, fold=None):
reverse_bits=False, justify=None, vertical_compression='medium', idle_wires=True,
with_layout=True, fold=None, ax=None):
"""Draw the quantum circuit
Using the output parameter you can specify the format. The choices are:
@ -610,12 +610,19 @@ class QuantumCircuit:
guess the console width using `shutil.get_terminal_size()`. However, if
running in jupyter, the default line length is set to 80 characters.
In `mpl` is the amount of operations before folding. Default is 25.
ax (matplotlib.axes.Axes): An optional Axes object to be used for
the visualization output. If none is specified a new matplotlib
Figure will be created and used. Additionally, if specified there
will be no returned Figure since it is redundant. This is only used
when the ``output`` kwarg is set to use the ``mpl`` backend. It
will be silently ignored with all other outputs.
Returns:
PIL.Image or matplotlib.figure or str or TextDrawing:
* PIL.Image: (output `latex`) an in-memory representation of the
image of the circuit diagram.
* matplotlib.figure: (output `mpl`) a matplotlib figure object
for the circuit diagram.
for the circuit diagram, if the ``ax`` kwarg is not set.
* str: (output `latex_source`). The LaTeX source code.
* TextDrawing: (output `text`). A drawing that can be printed as
ascii art
@ -635,7 +642,9 @@ class QuantumCircuit:
justify=justify,
vertical_compression=vertical_compression,
idle_wires=idle_wires,
with_layout=with_layout, fold=fold)
with_layout=with_layout,
fold=fold,
ax=ax)
def size(self):
"""Returns total number of gate operations in circuit.

View File

@ -56,7 +56,8 @@ def circuit_drawer(circuit,
vertical_compression='medium',
idle_wires=True,
with_layout=True,
fold=None):
fold=None,
ax=None):
"""Draw a quantum circuit to different formats (set by output parameter):
0. text: ASCII art TextDrawing that can be printed in the console.
1. latex: high-quality images, but heavy external software dependencies
@ -103,11 +104,17 @@ def circuit_drawer(circuit,
guess the console width using `shutil.get_terminal_size()`. However, if
running in jupyter, the default line length is set to 80 characters.
In `mpl` is the amount of operations before folding. Default is 25.
ax (matplotlib.axes.Axes): An optional Axes object to be used for
the visualization output. If none is specified a new matplotlib
Figure will be created and used. Additionally, if specified there
will be no returned Figure since it is redundant. This is only used
when the ``output`` kwarg is set to use the ``mpl`` backend. It
will be silently ignored with all other outputs.
Returns:
PIL.Image: (output `latex`) an in-memory representation of the image
of the circuit diagram.
matplotlib.figure: (output `mpl`) a matplotlib figure object for the
circuit diagram.
circuit diagram, if the ``ax`` kwarg is not set
String: (output `latex_source`). The LaTeX source code.
TextDrawing: (output `text`). A drawing that can be printed as ascii art
Raises:
@ -257,7 +264,8 @@ def circuit_drawer(circuit,
justify=justify,
idle_wires=idle_wires,
with_layout=with_layout,
fold=fold)
fold=fold,
ax=ax)
else:
raise exceptions.VisualizationError(
'Invalid output type %s selected. The only valid choices '
@ -538,7 +546,8 @@ def _matplotlib_circuit_drawer(circuit,
justify=None,
idle_wires=True,
with_layout=True,
fold=None):
fold=None,
ax=None):
"""Draw a quantum circuit based on matplotlib.
If `%matplotlib inline` is invoked in a Jupyter notebook, it visualizes a circuit inline.
We recommend `%config InlineBackend.figure_format = 'svg'` for the inline visualization.
@ -558,8 +567,14 @@ def _matplotlib_circuit_drawer(circuit,
with_layout (bool): Include layout information, with labels on the physical
layout. Default: True.
fold (int): amount ops allowed before folding. Default is 25.
ax (matplotlib.axes.Axes): An optional Axes object to be used for
the visualization output. If none is specified a new matplotlib
Figure will be created and used. Additionally, if specified there
will be no returned Figure since it is redundant.
Returns:
matplotlib.figure: a matplotlib figure object for the circuit diagram
if the ``ax`` kwarg is not set.
"""
qregs, cregs, ops = utils._get_layered_instructions(circuit,
@ -576,5 +591,6 @@ def _matplotlib_circuit_drawer(circuit,
qcd = _matplotlib.MatplotlibDrawer(qregs, cregs, ops, scale=scale, style=style,
plot_barriers=plot_barriers,
reverse_bits=reverse_bits, layout=layout, fold=fold)
reverse_bits=reverse_bits, layout=layout,
fold=fold, ax=ax)
return qcd.draw(filename)

View File

@ -52,7 +52,7 @@ DIST_MEAS = {'hamming': hamming_distance}
def plot_histogram(data, figsize=(7, 5), color=None, number_to_keep=None,
sort='asc', target_string=None,
legend=None, bar_labels=True, title=None):
legend=None, bar_labels=True, title=None, ax=None):
"""Plot a histogram of data.
Args:
@ -69,9 +69,14 @@ def plot_histogram(data, figsize=(7, 5), color=None, number_to_keep=None,
list or 1 if it's a dict)
bar_labels (bool): Label each bar in histogram with probability value.
title (str): A string to use for the plot title
ax (matplotlib.axes.Axes): An optional Axes object to be used for
the visualization output. If none is specified a new matplotlib
Figure will be created and used. Additionally, if specified there
will be no returned Figure since it is redundant.
Returns:
matplotlib.Figure: A figure for the rendered histogram.
matplotlib.Figure: A figure for the rendered histogram, if the ``ax``
kwarg is not set.
Raises:
ImportError: Matplotlib not available.
@ -95,8 +100,11 @@ def plot_histogram(data, figsize=(7, 5), color=None, number_to_keep=None,
raise VisualizationError("Length of legendL (%s) doesn't match "
"number of input executions: %s" %
(len(legend), len(data)))
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = None
fig, ax = plt.subplots(figsize=figsize)
labels = list(sorted(
functools.reduce(lambda x, y: x.union(y.keys()), data, set())))
if number_to_keep is not None:

View File

@ -12,7 +12,7 @@
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name,missing-docstring
# pylint: disable=invalid-name,missing-docstring,inconsistent-return-statements
"""mpl circuit visualization backend."""
@ -103,7 +103,7 @@ class Anchor:
class MatplotlibDrawer:
def __init__(self, qregs, cregs, ops,
scale=1.0, style=None, plot_barriers=True,
reverse_bits=False, layout=None, fold=25):
reverse_bits=False, layout=None, fold=25, ax=None):
if not HAS_MATPLOTLIB:
raise ImportError('The class MatplotlibDrawer needs matplotlib. '
@ -145,14 +145,21 @@ class MatplotlibDrawer:
with open(style, 'r') as infile:
dic = json.load(infile)
self._style.set_style(dic)
if ax is None:
self.return_fig = True
self.figure = plt.figure()
self.figure.patch.set_facecolor(color=self._style.bg)
self.ax = self.figure.add_subplot(111)
else:
self.return_fig = False
self.ax = ax
self.figure = ax.get_figure()
self.fold = self._style.fold or fold # self._style.fold should be removed after 0.10
# TODO: self._style.fold should be removed after deprecation
self.fold = self._style.fold or fold
if self.fold < 2:
self.fold = -1
self.figure = plt.figure()
self.figure.patch.set_facecolor(color=self._style.bg)
self.ax = self.figure.add_subplot(111)
self.ax.axis('off')
self.ax.set_aspect('equal')
self.ax.tick_params(labelbottom=False, labeltop=False,
@ -503,10 +510,11 @@ class MatplotlibDrawer:
if filename:
self.figure.savefig(filename, dpi=self._style.dpi,
bbox_inches='tight')
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(self.figure)
return self.figure
if self.return_fig:
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(self.figure)
return self.figure
def _draw_regs(self):

View File

@ -13,6 +13,7 @@
# that they have been altered from the originals.
# pylint: disable=invalid-name,ungrouped-imports,import-error
# pylint: disable=inconsistent-return-statements
"""
Visualization functions for quantum states.
@ -58,15 +59,29 @@ if HAS_MATPLOTLIB:
FancyArrowPatch.draw(self, renderer)
def plot_state_hinton(rho, title='', figsize=None):
def plot_state_hinton(rho, title='', figsize=None, ax_real=None, ax_imag=None):
"""Plot a hinton diagram for the quanum state.
Args:
rho (ndarray): Numpy array for state vector or density matrix.
title (str): a string that represents the plot title
figsize (tuple): Figure size in inches.
ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
the visualization output. If none is specified a new matplotlib
Figure will be created and used. If this is specified without an
ax_imag only the real component plot will be generated.
Additionally, if specified there will be no returned Figure since
it is redundant.
ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
the visualization output. If none is specified a new matplotlib
Figure will be created and used. If this is specified without an
ax_imag only the real component plot will be generated.
Additionally, if specified there will be no returned Figure since
it is redundant.
Returns:
matplotlib.Figure: The matplotlib.Figure of the visualization
matplotlib.Figure: The matplotlib.Figure of the visualization if
neither ax_real or ax_imag is set.
Raises:
ImportError: Requires matplotlib.
@ -77,7 +92,15 @@ def plot_state_hinton(rho, title='', figsize=None):
if figsize is None:
figsize = (8, 5)
num = int(np.log2(len(rho)))
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
if not ax_real and not ax_imag:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
else:
if ax_real:
fig = ax_real.get_figure()
else:
fig = ax_imag.get_figure()
ax1 = ax_real
ax2 = ax_imag
max_weight = 2 ** np.ceil(np.log(np.abs(rho).max()) / np.log(2))
datareal = np.real(rho)
dataimag = np.imag(rho)
@ -86,52 +109,55 @@ def plot_state_hinton(rho, title='', figsize=None):
lx = len(datareal[0]) # Work out matrix dimensions
ly = len(datareal[:, 0])
# Real
ax1.patch.set_facecolor('gray')
ax1.set_aspect('equal', 'box')
ax1.xaxis.set_major_locator(plt.NullLocator())
ax1.yaxis.set_major_locator(plt.NullLocator())
if ax1:
ax1.patch.set_facecolor('gray')
ax1.set_aspect('equal', 'box')
ax1.xaxis.set_major_locator(plt.NullLocator())
ax1.yaxis.set_major_locator(plt.NullLocator())
for (x, y), w in np.ndenumerate(datareal):
color = 'white' if w > 0 else 'black'
size = np.sqrt(np.abs(w) / max_weight)
rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
facecolor=color, edgecolor=color)
ax1.add_patch(rect)
for (x, y), w in np.ndenumerate(datareal):
color = 'white' if w > 0 else 'black'
size = np.sqrt(np.abs(w) / max_weight)
rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
facecolor=color, edgecolor=color)
ax1.add_patch(rect)
ax1.set_xticks(np.arange(0, lx+0.5, 1))
ax1.set_yticks(np.arange(0, ly+0.5, 1))
ax1.set_yticklabels(row_names, fontsize=14)
ax1.set_xticklabels(column_names, fontsize=14, rotation=90)
ax1.autoscale_view()
ax1.invert_yaxis()
ax1.set_title('Re[$\\rho$]', fontsize=14)
ax1.set_xticks(np.arange(0, lx+0.5, 1))
ax1.set_yticks(np.arange(0, ly+0.5, 1))
ax1.set_yticklabels(row_names, fontsize=14)
ax1.set_xticklabels(column_names, fontsize=14, rotation=90)
ax1.autoscale_view()
ax1.invert_yaxis()
ax1.set_title('Re[$\\rho$]', fontsize=14)
# Imaginary
ax2.patch.set_facecolor('gray')
ax2.set_aspect('equal', 'box')
ax2.xaxis.set_major_locator(plt.NullLocator())
ax2.yaxis.set_major_locator(plt.NullLocator())
if ax2:
ax2.patch.set_facecolor('gray')
ax2.set_aspect('equal', 'box')
ax2.xaxis.set_major_locator(plt.NullLocator())
ax2.yaxis.set_major_locator(plt.NullLocator())
for (x, y), w in np.ndenumerate(dataimag):
color = 'white' if w > 0 else 'black'
size = np.sqrt(np.abs(w) / max_weight)
rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
facecolor=color, edgecolor=color)
ax2.add_patch(rect)
for (x, y), w in np.ndenumerate(dataimag):
color = 'white' if w > 0 else 'black'
size = np.sqrt(np.abs(w) / max_weight)
rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
facecolor=color, edgecolor=color)
ax2.add_patch(rect)
ax2.set_xticks(np.arange(0, lx+0.5, 1))
ax2.set_yticks(np.arange(0, ly+0.5, 1))
ax2.set_yticklabels(row_names, fontsize=14)
ax2.set_xticklabels(column_names, fontsize=14, rotation=90)
ax2.set_xticks(np.arange(0, lx+0.5, 1))
ax2.set_yticks(np.arange(0, ly+0.5, 1))
ax2.set_yticklabels(row_names, fontsize=14)
ax2.set_xticklabels(column_names, fontsize=14, rotation=90)
ax2.autoscale_view()
ax2.invert_yaxis()
ax2.set_title('Im[$\\rho$]', fontsize=14)
ax2.autoscale_view()
ax2.invert_yaxis()
ax2.set_title('Im[$\\rho$]', fontsize=14)
if title:
fig.suptitle(title, fontsize=16)
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(fig)
return fig
if ax_real is None and ax_imag is None:
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(fig)
return fig
def plot_bloch_vector(bloch, title="", ax=None, figsize=None):
@ -142,7 +168,8 @@ def plot_bloch_vector(bloch, title="", ax=None, figsize=None):
Args:
bloch (list[double]): array of three elements where [<x>, <y>, <z>]
title (str): a string that represents the plot title
ax (matplotlib.Axes): An Axes to use for rendering the bloch sphere
ax (matplotlib.axes.Axes): An Axes to use for rendering the bloch
sphere
figsize (tuple): Figure size in inches. Has no effect is passing `ax`.
Returns:
@ -179,7 +206,7 @@ def plot_bloch_multivector(rho, title='', figsize=None):
figsize (tuple): Has no effect, here for compatibility only.
Returns:
Figure: A matplotlib figure instance if `ax = None`.
Figure: A matplotlib figure instance.
Raises:
ImportError: Requires matplotlib.
@ -210,7 +237,7 @@ def plot_bloch_multivector(rho, title='', figsize=None):
def plot_state_city(rho, title="", figsize=None, color=None,
alpha=1):
alpha=1, ax_real=None, ax_imag=None):
"""Plot the cityscape of quantum state.
Plot two 3d bar graphs (two dimensional) of the real and imaginary
@ -223,8 +250,22 @@ def plot_state_city(rho, title="", figsize=None, color=None,
color (list): A list of len=2 giving colors for real and
imaginary components of matrix elements.
alpha (float): Transparency value for bars
ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
the visualization output. If none is specified a new matplotlib
Figure will be created and used. If this is specified without an
ax_imag only the real component plot will be generated.
Additionally, if specified there will be no returned Figure since
it is redundant.
ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
the visualization output. If none is specified a new matplotlib
Figure will be created and used. If this is specified without an
ax_imag only the real component plot will be generated.
Additionally, if specified there will be no returned Figure since
it is redundant.
Returns:
matplotlib.Figure: The matplotlib.Figure of the visualization
matplotlib.Figure: The matplotlib.Figure of the visualization if the
``ax_real`` and ``ax_imag`` kwargs are not set
Raises:
ImportError: Requires matplotlib.
@ -267,99 +308,112 @@ def plot_state_city(rho, title="", figsize=None, color=None,
color[0] = "#648fff"
if color[1] is None:
color[1] = "#648fff"
if ax_real is None and ax_imag is None:
# set default figure size
if figsize is None:
figsize = (15, 5)
# set default figure size
if figsize is None:
figsize = (15, 5)
fig = plt.figure(figsize=figsize)
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
fig = plt.figure(figsize=figsize)
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
elif ax_real is not None:
fig = ax_real.get_figure()
ax1 = ax_real
if ax_imag is not None:
ax2 = ax_imag
else:
fig = ax_imag.get_figure()
ax1 = None
ax2 = ax_imag
x = [0, max(xpos)+0.5, max(xpos)+0.5, 0]
y = [0, 0, max(ypos)+0.5, max(ypos)+0.5]
z = [0, 0, 0, 0]
verts = [list(zip(x, y, z))]
fc1 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzr, color[0])
for idx, cur_zpos in enumerate(zpos):
if dzr[idx] > 0:
zorder = 2
else:
zorder = 0
b1 = ax1.bar3d(xpos[idx], ypos[idx], cur_zpos,
dx[idx], dy[idx], dzr[idx],
alpha=alpha, zorder=zorder)
b1.set_facecolors(fc1[6*idx:6*idx+6])
if ax1 is not None:
fc1 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzr, color[0])
for idx, cur_zpos in enumerate(zpos):
if dzr[idx] > 0:
zorder = 2
else:
zorder = 0
b1 = ax1.bar3d(xpos[idx], ypos[idx], cur_zpos,
dx[idx], dy[idx], dzr[idx],
alpha=alpha, zorder=zorder)
b1.set_facecolors(fc1[6*idx:6*idx+6])
pc1 = Poly3DCollection(verts, alpha=0.15, facecolor='k',
linewidths=1, zorder=1)
pc1 = Poly3DCollection(verts, alpha=0.15, facecolor='k',
linewidths=1, zorder=1)
if min(dzr) < 0 < max(dzr):
ax1.add_collection3d(pc1)
if min(dzr) < 0 < max(dzr):
ax1.add_collection3d(pc1)
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
fc2 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzi, color[1])
for idx, cur_zpos in enumerate(zpos):
if dzi[idx] > 0:
zorder = 2
else:
zorder = 0
b2 = ax2.bar3d(xpos[idx], ypos[idx], cur_zpos,
dx[idx], dy[idx], dzi[idx],
alpha=alpha, zorder=zorder)
b2.set_facecolors(fc2[6*idx:6*idx+6])
if ax2 is not None:
fc2 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzi, color[1])
for idx, cur_zpos in enumerate(zpos):
if dzi[idx] > 0:
zorder = 2
else:
zorder = 0
b2 = ax2.bar3d(xpos[idx], ypos[idx], cur_zpos,
dx[idx], dy[idx], dzi[idx],
alpha=alpha, zorder=zorder)
b2.set_facecolors(fc2[6*idx:6*idx+6])
pc2 = Poly3DCollection(verts, alpha=0.2, facecolor='k',
linewidths=1, zorder=1)
if min(dzi) < 0 < max(dzi):
ax2.add_collection3d(pc2)
ax1.set_xticks(np.arange(0.5, lx+0.5, 1))
ax1.set_yticks(np.arange(0.5, ly+0.5, 1))
max_dzr = max(dzr)
min_dzr = min(dzr)
if max_dzr != min_dzr:
ax1.axes.set_zlim3d(np.min(dzr), np.max(dzr)+1e-9)
else:
if min_dzr == 0:
pc2 = Poly3DCollection(verts, alpha=0.2, facecolor='k',
linewidths=1, zorder=1)
if min(dzi) < 0 < max(dzi):
ax2.add_collection3d(pc2)
if ax1 is not None:
ax1.set_xticks(np.arange(0.5, lx+0.5, 1))
ax1.set_yticks(np.arange(0.5, ly+0.5, 1))
max_dzr = max(dzr)
min_dzr = min(dzr)
if max_dzr != min_dzr:
ax1.axes.set_zlim3d(np.min(dzr), np.max(dzr)+1e-9)
else:
ax1.axes.set_zlim3d(auto=True)
ax1.zaxis.set_major_locator(MaxNLocator(5))
ax1.w_xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
ax1.w_yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
ax1.set_zlabel("Real[rho]", fontsize=14)
for tick in ax1.zaxis.get_major_ticks():
tick.label.set_fontsize(14)
ax2.set_xticks(np.arange(0.5, lx+0.5, 1))
ax2.set_yticks(np.arange(0.5, ly+0.5, 1))
min_dzi = np.min(dzi)
max_dzi = np.max(dzi)
if min_dzi != max_dzi:
eps = 0
ax2.zaxis.set_major_locator(MaxNLocator(5))
ax2.axes.set_zlim3d(np.min(dzi), np.max(dzi)+eps)
else:
if min_dzi == 0:
ax2.set_zticks([0])
eps = 1e-9
if min_dzr == 0:
ax1.axes.set_zlim3d(np.min(dzr), np.max(dzr)+1e-9)
else:
ax1.axes.set_zlim3d(auto=True)
ax1.zaxis.set_major_locator(MaxNLocator(5))
ax1.w_xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
ax1.w_yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
ax1.set_zlabel("Real[rho]", fontsize=14)
for tick in ax1.zaxis.get_major_ticks():
tick.label.set_fontsize(14)
if ax2 is not None:
ax2.set_xticks(np.arange(0.5, lx+0.5, 1))
ax2.set_yticks(np.arange(0.5, ly+0.5, 1))
min_dzi = np.min(dzi)
max_dzi = np.max(dzi)
if min_dzi != max_dzi:
eps = 0
ax2.zaxis.set_major_locator(MaxNLocator(5))
ax2.axes.set_zlim3d(np.min(dzi), np.max(dzi)+eps)
else:
ax2.axes.set_zlim3d(auto=True)
ax2.w_xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
ax2.w_yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
ax2.set_zlabel("Imag[rho]", fontsize=14)
for tick in ax2.zaxis.get_major_ticks():
tick.label.set_fontsize(14)
plt.suptitle(title, fontsize=16)
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(fig)
return fig
if min_dzi == 0:
ax2.set_zticks([0])
eps = 1e-9
ax2.axes.set_zlim3d(np.min(dzi), np.max(dzi)+eps)
else:
ax2.axes.set_zlim3d(auto=True)
ax2.w_xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
ax2.w_yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
ax2.set_zlabel("Imag[rho]", fontsize=14)
for tick in ax2.zaxis.get_major_ticks():
tick.label.set_fontsize(14)
fig.suptitle(title, fontsize=16)
if ax_real is None and ax_imag is None:
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(fig)
return fig
def plot_state_paulivec(rho, title="", figsize=None, color=None):
def plot_state_paulivec(rho, title="", figsize=None, color=None, ax=None):
"""Plot the paulivec representation of a quantum state.
Plot a bargraph of the mixed state rho over the pauli matrices
@ -369,8 +423,14 @@ def plot_state_paulivec(rho, title="", figsize=None, color=None):
title (str): a string that represents the plot title
figsize (tuple): Figure size in inches.
color (list or str): Color of the expectation value bars.
ax (matplotlib.axes.Axes): An optional Axes object to be used for
the visualization output. If none is specified a new matplotlib
Figure will be created and used. Additionally, if specified there
will be no returned Figure since it is redundant.
Returns:
matplotlib.Figure: The matplotlib.Figure of the visualization
matplotlib.Figure: The matplotlib.Figure of the visualization if the
``ax`` kwarg is not set
Raises:
ImportError: Requires matplotlib.
"""
@ -389,7 +449,12 @@ def plot_state_paulivec(rho, title="", figsize=None, color=None):
ind = np.arange(numelem) # the x locations for the groups
width = 0.5 # the width of the bars
fig, ax = plt.subplots(figsize=figsize)
if ax is None:
return_fig = True
fig, ax = plt.subplots(figsize=figsize)
else:
return_fig = False
fig = ax.get_figure()
ax.grid(zorder=0, linewidth=1, linestyle='--')
ax.bar(ind, values, width, color=color, zorder=2)
ax.axhline(linewidth=1, color='k')
@ -404,10 +469,11 @@ def plot_state_paulivec(rho, title="", figsize=None, color=None):
for tick in ax.xaxis.get_major_ticks()+ax.yaxis.get_major_ticks():
tick.label.set_fontsize(14)
ax.set_title(title, fontsize=16)
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(fig)
return fig
if return_fig:
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(fig)
return fig
def n_choose_k(n, k):
@ -469,7 +535,7 @@ def phase_to_rgb(complex_number):
return rgb
def plot_state_qsphere(rho, figsize=None):
def plot_state_qsphere(rho, figsize=None, ax=None):
"""Plot the qsphere representation of a quantum state.
Here, the size of the points is proportional to the probability
of the corresponding term in the state and the color represents
@ -479,9 +545,13 @@ def plot_state_qsphere(rho, figsize=None):
rho (ndarray): State vector or density matrix representation.
of quantum state.
figsize (tuple): Figure size in inches.
ax (matplotlib.axes.Axes): An optional Axes object to be used for
the visualization output. If none is specified a new matplotlib
Figure will be created and used. Additionally, if specified there
will be no returned Figure since it is redundant.
Returns:
Figure: A matplotlib figure instance.
Figure: A matplotlib figure instance if the ``ax`` kwag is not set
Raises:
ImportError: Requires matplotlib.
@ -501,7 +571,13 @@ def plot_state_qsphere(rho, figsize=None):
# get the eigenvectors and eigenvalues
we, stateall = linalg.eigh(rho)
fig = plt.figure(figsize=figsize)
if ax is None:
return_fig = True
fig = plt.figure(figsize=figsize)
else:
return_fig = False
fig = ax.get_figure()
gs = gridspec.GridSpec(nrows=3, ncols=3)
ax = fig.add_subplot(gs[0:3, 0:3], projection='3d')
@ -635,10 +711,11 @@ def plot_state_qsphere(rho, figsize=None):
ax2.text(0, -offset, r'$3\pi/2$', horizontalalignment='center',
verticalalignment='center', fontsize=14)
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(fig)
return fig
if return_fig:
if get_backend() in ['module://ipykernel.pylab.backend_inline',
'nbAgg']:
plt.close(fig)
return fig
def generate_facecolors(x, y, z, dx, dy, dz, color):

View File

@ -0,0 +1,37 @@
---
features:
- |
An ``ax`` kwarg has been added to the following visualization functions:
* ``qiskit.visualization.plot_histogram``
* ``qiskit.visualization.plot_state_paulivec``
* ``qiskit.visualization.plot_state_qsphere``
* ``qiskit.visualization.circuit_drawer`` (``mpl`` backend only)
* ``qiskit.QuantumCircuit.draw`` (``mpl`` backend only)
This kwarg is used to pass in a ``matplotlib.axes.Axes`` object to the
visualization functions. This enables integrating these visualization
functions into a larger visualization workflow. Also, if an `ax` kwarg is
specified then there is no return from the visualization functions.
- |
An ``ax_real`` and ``ax_imag`` kwarg has been added to the
following visualization functions:
* ``qiskit.visualization.plot_state_hinton``
* ``qiskit.visualization.plot_state_city``
These new kargs work the same as the newly added ``ax`` kwargs for other
visualization functions. However because these plots use two axes (one for
the real component, the other for the imaginary component). Having two
kwargs also provides the flexibility to only generate a visualization for
one of the components instead of always doing both. For example::
from matplotlib import pyplot as plt
from qiskit.visualization import plot_state_hinton
ax = plt.gca()
plot_state_hinton(psi, ax_real=ax)
will only generate a plot of the real component.