mirror of https://github.com/Qiskit/qiskit.git
Add ax kwarg to matplotlib visualization functions (#3053)
* Add ax kwarg to matplotlib visualization functions
This commit adds support for passing a matplotlib axes to visualization
functions. This enables integrating the qiskit visualizations into a
larger matplotlib visualization. When an ax kwarg is set that input will
be used for the output visualization and a new figure will not be
generated or returned.
Fixes #1950
* typo in release note
* Add ax kwarg to matplotlib visualization functions
This commit adds support for passing a matplotlib axes to visualization
functions. This enables integrating the qiskit visualizations into a
larger matplotlib visualization. When an ax kwarg is set that input will
be used for the output visualization and a new figure will not be
generated or returned.
Fixes #1950
* Fix lint and typos
* pylint: disable=invalid-name
* Fix lost lines from rebase
As part of rebasing and cleaning up the branch from several failed merge
attempts a few lines from plot_state_city() were lost. This commit fixes
that mistake and restores the z axis labels which were missing from the
imaginary component subplot.
* Fix rebase typo
* Revert "pylint: disable=invalid-name"
This reverts commit 5566611719
.
This commit is contained in:
parent
1cb65623b4
commit
352f940dfa
|
@ -172,7 +172,7 @@ attr-rgx=[a-z_][a-z0-9_]{2,30}$
|
|||
attr-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
||||
# Regular expression matching correct argument names
|
||||
argument-rgx=[a-z_][a-z0-9_]{2,30}$
|
||||
argument-rgx=[a-z_][a-z0-9_]{2,30}|ax$
|
||||
|
||||
# Naming hint for argument names
|
||||
argument-name-hint=[a-z_][a-z0-9_]{2,30}$
|
||||
|
|
|
@ -560,8 +560,8 @@ class QuantumCircuit:
|
|||
|
||||
def draw(self, scale=0.7, filename=None, style=None, output=None,
|
||||
interactive=False, line_length=None, plot_barriers=True,
|
||||
reverse_bits=False, justify=None, idle_wires=True, vertical_compression='medium',
|
||||
with_layout=True, fold=None):
|
||||
reverse_bits=False, justify=None, vertical_compression='medium', idle_wires=True,
|
||||
with_layout=True, fold=None, ax=None):
|
||||
"""Draw the quantum circuit
|
||||
|
||||
Using the output parameter you can specify the format. The choices are:
|
||||
|
@ -610,12 +610,19 @@ class QuantumCircuit:
|
|||
guess the console width using `shutil.get_terminal_size()`. However, if
|
||||
running in jupyter, the default line length is set to 80 characters.
|
||||
In `mpl` is the amount of operations before folding. Default is 25.
|
||||
ax (matplotlib.axes.Axes): An optional Axes object to be used for
|
||||
the visualization output. If none is specified a new matplotlib
|
||||
Figure will be created and used. Additionally, if specified there
|
||||
will be no returned Figure since it is redundant. This is only used
|
||||
when the ``output`` kwarg is set to use the ``mpl`` backend. It
|
||||
will be silently ignored with all other outputs.
|
||||
|
||||
Returns:
|
||||
PIL.Image or matplotlib.figure or str or TextDrawing:
|
||||
* PIL.Image: (output `latex`) an in-memory representation of the
|
||||
image of the circuit diagram.
|
||||
* matplotlib.figure: (output `mpl`) a matplotlib figure object
|
||||
for the circuit diagram.
|
||||
for the circuit diagram, if the ``ax`` kwarg is not set.
|
||||
* str: (output `latex_source`). The LaTeX source code.
|
||||
* TextDrawing: (output `text`). A drawing that can be printed as
|
||||
ascii art
|
||||
|
@ -635,7 +642,9 @@ class QuantumCircuit:
|
|||
justify=justify,
|
||||
vertical_compression=vertical_compression,
|
||||
idle_wires=idle_wires,
|
||||
with_layout=with_layout, fold=fold)
|
||||
with_layout=with_layout,
|
||||
fold=fold,
|
||||
ax=ax)
|
||||
|
||||
def size(self):
|
||||
"""Returns total number of gate operations in circuit.
|
||||
|
|
|
@ -56,7 +56,8 @@ def circuit_drawer(circuit,
|
|||
vertical_compression='medium',
|
||||
idle_wires=True,
|
||||
with_layout=True,
|
||||
fold=None):
|
||||
fold=None,
|
||||
ax=None):
|
||||
"""Draw a quantum circuit to different formats (set by output parameter):
|
||||
0. text: ASCII art TextDrawing that can be printed in the console.
|
||||
1. latex: high-quality images, but heavy external software dependencies
|
||||
|
@ -103,11 +104,17 @@ def circuit_drawer(circuit,
|
|||
guess the console width using `shutil.get_terminal_size()`. However, if
|
||||
running in jupyter, the default line length is set to 80 characters.
|
||||
In `mpl` is the amount of operations before folding. Default is 25.
|
||||
ax (matplotlib.axes.Axes): An optional Axes object to be used for
|
||||
the visualization output. If none is specified a new matplotlib
|
||||
Figure will be created and used. Additionally, if specified there
|
||||
will be no returned Figure since it is redundant. This is only used
|
||||
when the ``output`` kwarg is set to use the ``mpl`` backend. It
|
||||
will be silently ignored with all other outputs.
|
||||
Returns:
|
||||
PIL.Image: (output `latex`) an in-memory representation of the image
|
||||
of the circuit diagram.
|
||||
matplotlib.figure: (output `mpl`) a matplotlib figure object for the
|
||||
circuit diagram.
|
||||
circuit diagram, if the ``ax`` kwarg is not set
|
||||
String: (output `latex_source`). The LaTeX source code.
|
||||
TextDrawing: (output `text`). A drawing that can be printed as ascii art
|
||||
Raises:
|
||||
|
@ -257,7 +264,8 @@ def circuit_drawer(circuit,
|
|||
justify=justify,
|
||||
idle_wires=idle_wires,
|
||||
with_layout=with_layout,
|
||||
fold=fold)
|
||||
fold=fold,
|
||||
ax=ax)
|
||||
else:
|
||||
raise exceptions.VisualizationError(
|
||||
'Invalid output type %s selected. The only valid choices '
|
||||
|
@ -538,7 +546,8 @@ def _matplotlib_circuit_drawer(circuit,
|
|||
justify=None,
|
||||
idle_wires=True,
|
||||
with_layout=True,
|
||||
fold=None):
|
||||
fold=None,
|
||||
ax=None):
|
||||
"""Draw a quantum circuit based on matplotlib.
|
||||
If `%matplotlib inline` is invoked in a Jupyter notebook, it visualizes a circuit inline.
|
||||
We recommend `%config InlineBackend.figure_format = 'svg'` for the inline visualization.
|
||||
|
@ -558,8 +567,14 @@ def _matplotlib_circuit_drawer(circuit,
|
|||
with_layout (bool): Include layout information, with labels on the physical
|
||||
layout. Default: True.
|
||||
fold (int): amount ops allowed before folding. Default is 25.
|
||||
ax (matplotlib.axes.Axes): An optional Axes object to be used for
|
||||
the visualization output. If none is specified a new matplotlib
|
||||
Figure will be created and used. Additionally, if specified there
|
||||
will be no returned Figure since it is redundant.
|
||||
|
||||
Returns:
|
||||
matplotlib.figure: a matplotlib figure object for the circuit diagram
|
||||
if the ``ax`` kwarg is not set.
|
||||
"""
|
||||
|
||||
qregs, cregs, ops = utils._get_layered_instructions(circuit,
|
||||
|
@ -576,5 +591,6 @@ def _matplotlib_circuit_drawer(circuit,
|
|||
|
||||
qcd = _matplotlib.MatplotlibDrawer(qregs, cregs, ops, scale=scale, style=style,
|
||||
plot_barriers=plot_barriers,
|
||||
reverse_bits=reverse_bits, layout=layout, fold=fold)
|
||||
reverse_bits=reverse_bits, layout=layout,
|
||||
fold=fold, ax=ax)
|
||||
return qcd.draw(filename)
|
||||
|
|
|
@ -52,7 +52,7 @@ DIST_MEAS = {'hamming': hamming_distance}
|
|||
|
||||
def plot_histogram(data, figsize=(7, 5), color=None, number_to_keep=None,
|
||||
sort='asc', target_string=None,
|
||||
legend=None, bar_labels=True, title=None):
|
||||
legend=None, bar_labels=True, title=None, ax=None):
|
||||
"""Plot a histogram of data.
|
||||
|
||||
Args:
|
||||
|
@ -69,9 +69,14 @@ def plot_histogram(data, figsize=(7, 5), color=None, number_to_keep=None,
|
|||
list or 1 if it's a dict)
|
||||
bar_labels (bool): Label each bar in histogram with probability value.
|
||||
title (str): A string to use for the plot title
|
||||
ax (matplotlib.axes.Axes): An optional Axes object to be used for
|
||||
the visualization output. If none is specified a new matplotlib
|
||||
Figure will be created and used. Additionally, if specified there
|
||||
will be no returned Figure since it is redundant.
|
||||
|
||||
Returns:
|
||||
matplotlib.Figure: A figure for the rendered histogram.
|
||||
matplotlib.Figure: A figure for the rendered histogram, if the ``ax``
|
||||
kwarg is not set.
|
||||
|
||||
Raises:
|
||||
ImportError: Matplotlib not available.
|
||||
|
@ -95,8 +100,11 @@ def plot_histogram(data, figsize=(7, 5), color=None, number_to_keep=None,
|
|||
raise VisualizationError("Length of legendL (%s) doesn't match "
|
||||
"number of input executions: %s" %
|
||||
(len(legend), len(data)))
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
else:
|
||||
fig = None
|
||||
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
labels = list(sorted(
|
||||
functools.reduce(lambda x, y: x.union(y.keys()), data, set())))
|
||||
if number_to_keep is not None:
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# copyright notice, and modified files need to carry a notice indicating
|
||||
# that they have been altered from the originals.
|
||||
|
||||
# pylint: disable=invalid-name,missing-docstring
|
||||
# pylint: disable=invalid-name,missing-docstring,inconsistent-return-statements
|
||||
|
||||
"""mpl circuit visualization backend."""
|
||||
|
||||
|
@ -103,7 +103,7 @@ class Anchor:
|
|||
class MatplotlibDrawer:
|
||||
def __init__(self, qregs, cregs, ops,
|
||||
scale=1.0, style=None, plot_barriers=True,
|
||||
reverse_bits=False, layout=None, fold=25):
|
||||
reverse_bits=False, layout=None, fold=25, ax=None):
|
||||
|
||||
if not HAS_MATPLOTLIB:
|
||||
raise ImportError('The class MatplotlibDrawer needs matplotlib. '
|
||||
|
@ -145,14 +145,21 @@ class MatplotlibDrawer:
|
|||
with open(style, 'r') as infile:
|
||||
dic = json.load(infile)
|
||||
self._style.set_style(dic)
|
||||
if ax is None:
|
||||
self.return_fig = True
|
||||
self.figure = plt.figure()
|
||||
self.figure.patch.set_facecolor(color=self._style.bg)
|
||||
self.ax = self.figure.add_subplot(111)
|
||||
else:
|
||||
self.return_fig = False
|
||||
self.ax = ax
|
||||
self.figure = ax.get_figure()
|
||||
|
||||
self.fold = self._style.fold or fold # self._style.fold should be removed after 0.10
|
||||
# TODO: self._style.fold should be removed after deprecation
|
||||
self.fold = self._style.fold or fold
|
||||
if self.fold < 2:
|
||||
self.fold = -1
|
||||
|
||||
self.figure = plt.figure()
|
||||
self.figure.patch.set_facecolor(color=self._style.bg)
|
||||
self.ax = self.figure.add_subplot(111)
|
||||
self.ax.axis('off')
|
||||
self.ax.set_aspect('equal')
|
||||
self.ax.tick_params(labelbottom=False, labeltop=False,
|
||||
|
@ -503,10 +510,11 @@ class MatplotlibDrawer:
|
|||
if filename:
|
||||
self.figure.savefig(filename, dpi=self._style.dpi,
|
||||
bbox_inches='tight')
|
||||
if get_backend() in ['module://ipykernel.pylab.backend_inline',
|
||||
'nbAgg']:
|
||||
plt.close(self.figure)
|
||||
return self.figure
|
||||
if self.return_fig:
|
||||
if get_backend() in ['module://ipykernel.pylab.backend_inline',
|
||||
'nbAgg']:
|
||||
plt.close(self.figure)
|
||||
return self.figure
|
||||
|
||||
def _draw_regs(self):
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# that they have been altered from the originals.
|
||||
|
||||
# pylint: disable=invalid-name,ungrouped-imports,import-error
|
||||
# pylint: disable=inconsistent-return-statements
|
||||
|
||||
"""
|
||||
Visualization functions for quantum states.
|
||||
|
@ -58,15 +59,29 @@ if HAS_MATPLOTLIB:
|
|||
FancyArrowPatch.draw(self, renderer)
|
||||
|
||||
|
||||
def plot_state_hinton(rho, title='', figsize=None):
|
||||
def plot_state_hinton(rho, title='', figsize=None, ax_real=None, ax_imag=None):
|
||||
"""Plot a hinton diagram for the quanum state.
|
||||
|
||||
Args:
|
||||
rho (ndarray): Numpy array for state vector or density matrix.
|
||||
title (str): a string that represents the plot title
|
||||
figsize (tuple): Figure size in inches.
|
||||
ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
|
||||
the visualization output. If none is specified a new matplotlib
|
||||
Figure will be created and used. If this is specified without an
|
||||
ax_imag only the real component plot will be generated.
|
||||
Additionally, if specified there will be no returned Figure since
|
||||
it is redundant.
|
||||
ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
|
||||
the visualization output. If none is specified a new matplotlib
|
||||
Figure will be created and used. If this is specified without an
|
||||
ax_imag only the real component plot will be generated.
|
||||
Additionally, if specified there will be no returned Figure since
|
||||
it is redundant.
|
||||
|
||||
Returns:
|
||||
matplotlib.Figure: The matplotlib.Figure of the visualization
|
||||
matplotlib.Figure: The matplotlib.Figure of the visualization if
|
||||
neither ax_real or ax_imag is set.
|
||||
|
||||
Raises:
|
||||
ImportError: Requires matplotlib.
|
||||
|
@ -77,7 +92,15 @@ def plot_state_hinton(rho, title='', figsize=None):
|
|||
if figsize is None:
|
||||
figsize = (8, 5)
|
||||
num = int(np.log2(len(rho)))
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
|
||||
if not ax_real and not ax_imag:
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
|
||||
else:
|
||||
if ax_real:
|
||||
fig = ax_real.get_figure()
|
||||
else:
|
||||
fig = ax_imag.get_figure()
|
||||
ax1 = ax_real
|
||||
ax2 = ax_imag
|
||||
max_weight = 2 ** np.ceil(np.log(np.abs(rho).max()) / np.log(2))
|
||||
datareal = np.real(rho)
|
||||
dataimag = np.imag(rho)
|
||||
|
@ -86,52 +109,55 @@ def plot_state_hinton(rho, title='', figsize=None):
|
|||
lx = len(datareal[0]) # Work out matrix dimensions
|
||||
ly = len(datareal[:, 0])
|
||||
# Real
|
||||
ax1.patch.set_facecolor('gray')
|
||||
ax1.set_aspect('equal', 'box')
|
||||
ax1.xaxis.set_major_locator(plt.NullLocator())
|
||||
ax1.yaxis.set_major_locator(plt.NullLocator())
|
||||
if ax1:
|
||||
ax1.patch.set_facecolor('gray')
|
||||
ax1.set_aspect('equal', 'box')
|
||||
ax1.xaxis.set_major_locator(plt.NullLocator())
|
||||
ax1.yaxis.set_major_locator(plt.NullLocator())
|
||||
|
||||
for (x, y), w in np.ndenumerate(datareal):
|
||||
color = 'white' if w > 0 else 'black'
|
||||
size = np.sqrt(np.abs(w) / max_weight)
|
||||
rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
|
||||
facecolor=color, edgecolor=color)
|
||||
ax1.add_patch(rect)
|
||||
for (x, y), w in np.ndenumerate(datareal):
|
||||
color = 'white' if w > 0 else 'black'
|
||||
size = np.sqrt(np.abs(w) / max_weight)
|
||||
rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
|
||||
facecolor=color, edgecolor=color)
|
||||
ax1.add_patch(rect)
|
||||
|
||||
ax1.set_xticks(np.arange(0, lx+0.5, 1))
|
||||
ax1.set_yticks(np.arange(0, ly+0.5, 1))
|
||||
ax1.set_yticklabels(row_names, fontsize=14)
|
||||
ax1.set_xticklabels(column_names, fontsize=14, rotation=90)
|
||||
ax1.autoscale_view()
|
||||
ax1.invert_yaxis()
|
||||
ax1.set_title('Re[$\\rho$]', fontsize=14)
|
||||
ax1.set_xticks(np.arange(0, lx+0.5, 1))
|
||||
ax1.set_yticks(np.arange(0, ly+0.5, 1))
|
||||
ax1.set_yticklabels(row_names, fontsize=14)
|
||||
ax1.set_xticklabels(column_names, fontsize=14, rotation=90)
|
||||
ax1.autoscale_view()
|
||||
ax1.invert_yaxis()
|
||||
ax1.set_title('Re[$\\rho$]', fontsize=14)
|
||||
# Imaginary
|
||||
ax2.patch.set_facecolor('gray')
|
||||
ax2.set_aspect('equal', 'box')
|
||||
ax2.xaxis.set_major_locator(plt.NullLocator())
|
||||
ax2.yaxis.set_major_locator(plt.NullLocator())
|
||||
if ax2:
|
||||
ax2.patch.set_facecolor('gray')
|
||||
ax2.set_aspect('equal', 'box')
|
||||
ax2.xaxis.set_major_locator(plt.NullLocator())
|
||||
ax2.yaxis.set_major_locator(plt.NullLocator())
|
||||
|
||||
for (x, y), w in np.ndenumerate(dataimag):
|
||||
color = 'white' if w > 0 else 'black'
|
||||
size = np.sqrt(np.abs(w) / max_weight)
|
||||
rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
|
||||
facecolor=color, edgecolor=color)
|
||||
ax2.add_patch(rect)
|
||||
for (x, y), w in np.ndenumerate(dataimag):
|
||||
color = 'white' if w > 0 else 'black'
|
||||
size = np.sqrt(np.abs(w) / max_weight)
|
||||
rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
|
||||
facecolor=color, edgecolor=color)
|
||||
ax2.add_patch(rect)
|
||||
|
||||
ax2.set_xticks(np.arange(0, lx+0.5, 1))
|
||||
ax2.set_yticks(np.arange(0, ly+0.5, 1))
|
||||
ax2.set_yticklabels(row_names, fontsize=14)
|
||||
ax2.set_xticklabels(column_names, fontsize=14, rotation=90)
|
||||
ax2.set_xticks(np.arange(0, lx+0.5, 1))
|
||||
ax2.set_yticks(np.arange(0, ly+0.5, 1))
|
||||
ax2.set_yticklabels(row_names, fontsize=14)
|
||||
ax2.set_xticklabels(column_names, fontsize=14, rotation=90)
|
||||
|
||||
ax2.autoscale_view()
|
||||
ax2.invert_yaxis()
|
||||
ax2.set_title('Im[$\\rho$]', fontsize=14)
|
||||
ax2.autoscale_view()
|
||||
ax2.invert_yaxis()
|
||||
ax2.set_title('Im[$\\rho$]', fontsize=14)
|
||||
if title:
|
||||
fig.suptitle(title, fontsize=16)
|
||||
if get_backend() in ['module://ipykernel.pylab.backend_inline',
|
||||
'nbAgg']:
|
||||
plt.close(fig)
|
||||
return fig
|
||||
if ax_real is None and ax_imag is None:
|
||||
if get_backend() in ['module://ipykernel.pylab.backend_inline',
|
||||
'nbAgg']:
|
||||
plt.close(fig)
|
||||
return fig
|
||||
|
||||
|
||||
def plot_bloch_vector(bloch, title="", ax=None, figsize=None):
|
||||
|
@ -142,7 +168,8 @@ def plot_bloch_vector(bloch, title="", ax=None, figsize=None):
|
|||
Args:
|
||||
bloch (list[double]): array of three elements where [<x>, <y>, <z>]
|
||||
title (str): a string that represents the plot title
|
||||
ax (matplotlib.Axes): An Axes to use for rendering the bloch sphere
|
||||
ax (matplotlib.axes.Axes): An Axes to use for rendering the bloch
|
||||
sphere
|
||||
figsize (tuple): Figure size in inches. Has no effect is passing `ax`.
|
||||
|
||||
Returns:
|
||||
|
@ -179,7 +206,7 @@ def plot_bloch_multivector(rho, title='', figsize=None):
|
|||
figsize (tuple): Has no effect, here for compatibility only.
|
||||
|
||||
Returns:
|
||||
Figure: A matplotlib figure instance if `ax = None`.
|
||||
Figure: A matplotlib figure instance.
|
||||
|
||||
Raises:
|
||||
ImportError: Requires matplotlib.
|
||||
|
@ -210,7 +237,7 @@ def plot_bloch_multivector(rho, title='', figsize=None):
|
|||
|
||||
|
||||
def plot_state_city(rho, title="", figsize=None, color=None,
|
||||
alpha=1):
|
||||
alpha=1, ax_real=None, ax_imag=None):
|
||||
"""Plot the cityscape of quantum state.
|
||||
|
||||
Plot two 3d bar graphs (two dimensional) of the real and imaginary
|
||||
|
@ -223,8 +250,22 @@ def plot_state_city(rho, title="", figsize=None, color=None,
|
|||
color (list): A list of len=2 giving colors for real and
|
||||
imaginary components of matrix elements.
|
||||
alpha (float): Transparency value for bars
|
||||
ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
|
||||
the visualization output. If none is specified a new matplotlib
|
||||
Figure will be created and used. If this is specified without an
|
||||
ax_imag only the real component plot will be generated.
|
||||
Additionally, if specified there will be no returned Figure since
|
||||
it is redundant.
|
||||
ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
|
||||
the visualization output. If none is specified a new matplotlib
|
||||
Figure will be created and used. If this is specified without an
|
||||
ax_imag only the real component plot will be generated.
|
||||
Additionally, if specified there will be no returned Figure since
|
||||
it is redundant.
|
||||
|
||||
Returns:
|
||||
matplotlib.Figure: The matplotlib.Figure of the visualization
|
||||
matplotlib.Figure: The matplotlib.Figure of the visualization if the
|
||||
``ax_real`` and ``ax_imag`` kwargs are not set
|
||||
|
||||
Raises:
|
||||
ImportError: Requires matplotlib.
|
||||
|
@ -267,99 +308,112 @@ def plot_state_city(rho, title="", figsize=None, color=None,
|
|||
color[0] = "#648fff"
|
||||
if color[1] is None:
|
||||
color[1] = "#648fff"
|
||||
if ax_real is None and ax_imag is None:
|
||||
# set default figure size
|
||||
if figsize is None:
|
||||
figsize = (15, 5)
|
||||
|
||||
# set default figure size
|
||||
if figsize is None:
|
||||
figsize = (15, 5)
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
|
||||
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
|
||||
elif ax_real is not None:
|
||||
fig = ax_real.get_figure()
|
||||
ax1 = ax_real
|
||||
if ax_imag is not None:
|
||||
ax2 = ax_imag
|
||||
else:
|
||||
fig = ax_imag.get_figure()
|
||||
ax1 = None
|
||||
ax2 = ax_imag
|
||||
|
||||
x = [0, max(xpos)+0.5, max(xpos)+0.5, 0]
|
||||
y = [0, 0, max(ypos)+0.5, max(ypos)+0.5]
|
||||
z = [0, 0, 0, 0]
|
||||
verts = [list(zip(x, y, z))]
|
||||
|
||||
fc1 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzr, color[0])
|
||||
for idx, cur_zpos in enumerate(zpos):
|
||||
if dzr[idx] > 0:
|
||||
zorder = 2
|
||||
else:
|
||||
zorder = 0
|
||||
b1 = ax1.bar3d(xpos[idx], ypos[idx], cur_zpos,
|
||||
dx[idx], dy[idx], dzr[idx],
|
||||
alpha=alpha, zorder=zorder)
|
||||
b1.set_facecolors(fc1[6*idx:6*idx+6])
|
||||
if ax1 is not None:
|
||||
fc1 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzr, color[0])
|
||||
for idx, cur_zpos in enumerate(zpos):
|
||||
if dzr[idx] > 0:
|
||||
zorder = 2
|
||||
else:
|
||||
zorder = 0
|
||||
b1 = ax1.bar3d(xpos[idx], ypos[idx], cur_zpos,
|
||||
dx[idx], dy[idx], dzr[idx],
|
||||
alpha=alpha, zorder=zorder)
|
||||
b1.set_facecolors(fc1[6*idx:6*idx+6])
|
||||
|
||||
pc1 = Poly3DCollection(verts, alpha=0.15, facecolor='k',
|
||||
linewidths=1, zorder=1)
|
||||
pc1 = Poly3DCollection(verts, alpha=0.15, facecolor='k',
|
||||
linewidths=1, zorder=1)
|
||||
|
||||
if min(dzr) < 0 < max(dzr):
|
||||
ax1.add_collection3d(pc1)
|
||||
if min(dzr) < 0 < max(dzr):
|
||||
ax1.add_collection3d(pc1)
|
||||
|
||||
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
|
||||
fc2 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzi, color[1])
|
||||
for idx, cur_zpos in enumerate(zpos):
|
||||
if dzi[idx] > 0:
|
||||
zorder = 2
|
||||
else:
|
||||
zorder = 0
|
||||
b2 = ax2.bar3d(xpos[idx], ypos[idx], cur_zpos,
|
||||
dx[idx], dy[idx], dzi[idx],
|
||||
alpha=alpha, zorder=zorder)
|
||||
b2.set_facecolors(fc2[6*idx:6*idx+6])
|
||||
if ax2 is not None:
|
||||
fc2 = generate_facecolors(xpos, ypos, zpos, dx, dy, dzi, color[1])
|
||||
for idx, cur_zpos in enumerate(zpos):
|
||||
if dzi[idx] > 0:
|
||||
zorder = 2
|
||||
else:
|
||||
zorder = 0
|
||||
b2 = ax2.bar3d(xpos[idx], ypos[idx], cur_zpos,
|
||||
dx[idx], dy[idx], dzi[idx],
|
||||
alpha=alpha, zorder=zorder)
|
||||
b2.set_facecolors(fc2[6*idx:6*idx+6])
|
||||
|
||||
pc2 = Poly3DCollection(verts, alpha=0.2, facecolor='k',
|
||||
linewidths=1, zorder=1)
|
||||
|
||||
if min(dzi) < 0 < max(dzi):
|
||||
ax2.add_collection3d(pc2)
|
||||
ax1.set_xticks(np.arange(0.5, lx+0.5, 1))
|
||||
ax1.set_yticks(np.arange(0.5, ly+0.5, 1))
|
||||
max_dzr = max(dzr)
|
||||
min_dzr = min(dzr)
|
||||
if max_dzr != min_dzr:
|
||||
ax1.axes.set_zlim3d(np.min(dzr), np.max(dzr)+1e-9)
|
||||
else:
|
||||
if min_dzr == 0:
|
||||
pc2 = Poly3DCollection(verts, alpha=0.2, facecolor='k',
|
||||
linewidths=1, zorder=1)
|
||||
if min(dzi) < 0 < max(dzi):
|
||||
ax2.add_collection3d(pc2)
|
||||
if ax1 is not None:
|
||||
ax1.set_xticks(np.arange(0.5, lx+0.5, 1))
|
||||
ax1.set_yticks(np.arange(0.5, ly+0.5, 1))
|
||||
max_dzr = max(dzr)
|
||||
min_dzr = min(dzr)
|
||||
if max_dzr != min_dzr:
|
||||
ax1.axes.set_zlim3d(np.min(dzr), np.max(dzr)+1e-9)
|
||||
else:
|
||||
ax1.axes.set_zlim3d(auto=True)
|
||||
ax1.zaxis.set_major_locator(MaxNLocator(5))
|
||||
ax1.w_xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
|
||||
ax1.w_yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
|
||||
ax1.set_zlabel("Real[rho]", fontsize=14)
|
||||
for tick in ax1.zaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(14)
|
||||
|
||||
ax2.set_xticks(np.arange(0.5, lx+0.5, 1))
|
||||
ax2.set_yticks(np.arange(0.5, ly+0.5, 1))
|
||||
min_dzi = np.min(dzi)
|
||||
max_dzi = np.max(dzi)
|
||||
if min_dzi != max_dzi:
|
||||
eps = 0
|
||||
ax2.zaxis.set_major_locator(MaxNLocator(5))
|
||||
ax2.axes.set_zlim3d(np.min(dzi), np.max(dzi)+eps)
|
||||
else:
|
||||
if min_dzi == 0:
|
||||
ax2.set_zticks([0])
|
||||
eps = 1e-9
|
||||
if min_dzr == 0:
|
||||
ax1.axes.set_zlim3d(np.min(dzr), np.max(dzr)+1e-9)
|
||||
else:
|
||||
ax1.axes.set_zlim3d(auto=True)
|
||||
ax1.zaxis.set_major_locator(MaxNLocator(5))
|
||||
ax1.w_xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
|
||||
ax1.w_yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
|
||||
ax1.set_zlabel("Real[rho]", fontsize=14)
|
||||
for tick in ax1.zaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(14)
|
||||
if ax2 is not None:
|
||||
ax2.set_xticks(np.arange(0.5, lx+0.5, 1))
|
||||
ax2.set_yticks(np.arange(0.5, ly+0.5, 1))
|
||||
min_dzi = np.min(dzi)
|
||||
max_dzi = np.max(dzi)
|
||||
if min_dzi != max_dzi:
|
||||
eps = 0
|
||||
ax2.zaxis.set_major_locator(MaxNLocator(5))
|
||||
ax2.axes.set_zlim3d(np.min(dzi), np.max(dzi)+eps)
|
||||
else:
|
||||
ax2.axes.set_zlim3d(auto=True)
|
||||
ax2.w_xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
|
||||
ax2.w_yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
|
||||
ax2.set_zlabel("Imag[rho]", fontsize=14)
|
||||
for tick in ax2.zaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(14)
|
||||
plt.suptitle(title, fontsize=16)
|
||||
if get_backend() in ['module://ipykernel.pylab.backend_inline',
|
||||
'nbAgg']:
|
||||
plt.close(fig)
|
||||
return fig
|
||||
if min_dzi == 0:
|
||||
ax2.set_zticks([0])
|
||||
eps = 1e-9
|
||||
ax2.axes.set_zlim3d(np.min(dzi), np.max(dzi)+eps)
|
||||
else:
|
||||
ax2.axes.set_zlim3d(auto=True)
|
||||
ax2.w_xaxis.set_ticklabels(row_names, fontsize=14, rotation=45)
|
||||
ax2.w_yaxis.set_ticklabels(column_names, fontsize=14, rotation=-22.5)
|
||||
ax2.set_zlabel("Imag[rho]", fontsize=14)
|
||||
for tick in ax2.zaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(14)
|
||||
fig.suptitle(title, fontsize=16)
|
||||
if ax_real is None and ax_imag is None:
|
||||
if get_backend() in ['module://ipykernel.pylab.backend_inline',
|
||||
'nbAgg']:
|
||||
plt.close(fig)
|
||||
return fig
|
||||
|
||||
|
||||
def plot_state_paulivec(rho, title="", figsize=None, color=None):
|
||||
def plot_state_paulivec(rho, title="", figsize=None, color=None, ax=None):
|
||||
"""Plot the paulivec representation of a quantum state.
|
||||
|
||||
Plot a bargraph of the mixed state rho over the pauli matrices
|
||||
|
@ -369,8 +423,14 @@ def plot_state_paulivec(rho, title="", figsize=None, color=None):
|
|||
title (str): a string that represents the plot title
|
||||
figsize (tuple): Figure size in inches.
|
||||
color (list or str): Color of the expectation value bars.
|
||||
ax (matplotlib.axes.Axes): An optional Axes object to be used for
|
||||
the visualization output. If none is specified a new matplotlib
|
||||
Figure will be created and used. Additionally, if specified there
|
||||
will be no returned Figure since it is redundant.
|
||||
|
||||
Returns:
|
||||
matplotlib.Figure: The matplotlib.Figure of the visualization
|
||||
matplotlib.Figure: The matplotlib.Figure of the visualization if the
|
||||
``ax`` kwarg is not set
|
||||
Raises:
|
||||
ImportError: Requires matplotlib.
|
||||
"""
|
||||
|
@ -389,7 +449,12 @@ def plot_state_paulivec(rho, title="", figsize=None, color=None):
|
|||
|
||||
ind = np.arange(numelem) # the x locations for the groups
|
||||
width = 0.5 # the width of the bars
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
if ax is None:
|
||||
return_fig = True
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
else:
|
||||
return_fig = False
|
||||
fig = ax.get_figure()
|
||||
ax.grid(zorder=0, linewidth=1, linestyle='--')
|
||||
ax.bar(ind, values, width, color=color, zorder=2)
|
||||
ax.axhline(linewidth=1, color='k')
|
||||
|
@ -404,10 +469,11 @@ def plot_state_paulivec(rho, title="", figsize=None, color=None):
|
|||
for tick in ax.xaxis.get_major_ticks()+ax.yaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(14)
|
||||
ax.set_title(title, fontsize=16)
|
||||
if get_backend() in ['module://ipykernel.pylab.backend_inline',
|
||||
'nbAgg']:
|
||||
plt.close(fig)
|
||||
return fig
|
||||
if return_fig:
|
||||
if get_backend() in ['module://ipykernel.pylab.backend_inline',
|
||||
'nbAgg']:
|
||||
plt.close(fig)
|
||||
return fig
|
||||
|
||||
|
||||
def n_choose_k(n, k):
|
||||
|
@ -469,7 +535,7 @@ def phase_to_rgb(complex_number):
|
|||
return rgb
|
||||
|
||||
|
||||
def plot_state_qsphere(rho, figsize=None):
|
||||
def plot_state_qsphere(rho, figsize=None, ax=None):
|
||||
"""Plot the qsphere representation of a quantum state.
|
||||
Here, the size of the points is proportional to the probability
|
||||
of the corresponding term in the state and the color represents
|
||||
|
@ -479,9 +545,13 @@ def plot_state_qsphere(rho, figsize=None):
|
|||
rho (ndarray): State vector or density matrix representation.
|
||||
of quantum state.
|
||||
figsize (tuple): Figure size in inches.
|
||||
ax (matplotlib.axes.Axes): An optional Axes object to be used for
|
||||
the visualization output. If none is specified a new matplotlib
|
||||
Figure will be created and used. Additionally, if specified there
|
||||
will be no returned Figure since it is redundant.
|
||||
|
||||
Returns:
|
||||
Figure: A matplotlib figure instance.
|
||||
Figure: A matplotlib figure instance if the ``ax`` kwag is not set
|
||||
|
||||
Raises:
|
||||
ImportError: Requires matplotlib.
|
||||
|
@ -501,7 +571,13 @@ def plot_state_qsphere(rho, figsize=None):
|
|||
# get the eigenvectors and eigenvalues
|
||||
we, stateall = linalg.eigh(rho)
|
||||
|
||||
fig = plt.figure(figsize=figsize)
|
||||
if ax is None:
|
||||
return_fig = True
|
||||
fig = plt.figure(figsize=figsize)
|
||||
else:
|
||||
return_fig = False
|
||||
fig = ax.get_figure()
|
||||
|
||||
gs = gridspec.GridSpec(nrows=3, ncols=3)
|
||||
|
||||
ax = fig.add_subplot(gs[0:3, 0:3], projection='3d')
|
||||
|
@ -635,10 +711,11 @@ def plot_state_qsphere(rho, figsize=None):
|
|||
ax2.text(0, -offset, r'$3\pi/2$', horizontalalignment='center',
|
||||
verticalalignment='center', fontsize=14)
|
||||
|
||||
if get_backend() in ['module://ipykernel.pylab.backend_inline',
|
||||
'nbAgg']:
|
||||
plt.close(fig)
|
||||
return fig
|
||||
if return_fig:
|
||||
if get_backend() in ['module://ipykernel.pylab.backend_inline',
|
||||
'nbAgg']:
|
||||
plt.close(fig)
|
||||
return fig
|
||||
|
||||
|
||||
def generate_facecolors(x, y, z, dx, dy, dz, color):
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
---
|
||||
features:
|
||||
- |
|
||||
An ``ax`` kwarg has been added to the following visualization functions:
|
||||
|
||||
* ``qiskit.visualization.plot_histogram``
|
||||
* ``qiskit.visualization.plot_state_paulivec``
|
||||
* ``qiskit.visualization.plot_state_qsphere``
|
||||
* ``qiskit.visualization.circuit_drawer`` (``mpl`` backend only)
|
||||
* ``qiskit.QuantumCircuit.draw`` (``mpl`` backend only)
|
||||
|
||||
This kwarg is used to pass in a ``matplotlib.axes.Axes`` object to the
|
||||
visualization functions. This enables integrating these visualization
|
||||
functions into a larger visualization workflow. Also, if an `ax` kwarg is
|
||||
specified then there is no return from the visualization functions.
|
||||
|
||||
- |
|
||||
An ``ax_real`` and ``ax_imag`` kwarg has been added to the
|
||||
following visualization functions:
|
||||
|
||||
* ``qiskit.visualization.plot_state_hinton``
|
||||
* ``qiskit.visualization.plot_state_city``
|
||||
|
||||
These new kargs work the same as the newly added ``ax`` kwargs for other
|
||||
visualization functions. However because these plots use two axes (one for
|
||||
the real component, the other for the imaginary component). Having two
|
||||
kwargs also provides the flexibility to only generate a visualization for
|
||||
one of the components instead of always doing both. For example::
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
from qiskit.visualization import plot_state_hinton
|
||||
|
||||
ax = plt.gca()
|
||||
|
||||
plot_state_hinton(psi, ax_real=ax)
|
||||
|
||||
will only generate a plot of the real component.
|
Loading…
Reference in New Issue