mirror of https://github.com/Qiskit/qiskit.git
Fixed error in plot_histogram when number_to_keep is smaller that the number of keys (#7481)
* plot_histogram fails when given number_to_keep parameter #7461 * Fix style * fix review comments * remove temporary the release notes * add release notes * Reformat release notes * Temp - remove code from release note * Add code example in release note. * Add code example in release note. * Add code example in release note. * add test for multiple executions display * change dictionary from OderedDict to defacultdict * small fixes * Correct use of optional tests * Reword release note * Improve documentation of `number_to_keep` * Fix crash on distance measures * Revert if. * refactor code & add test Co-authored-by: Jake Lishman <jake.lishman@ibm.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
parent
a84380d431
commit
1658bb4626
|
@ -14,7 +14,7 @@
|
|||
Visualization functions for measurement counts.
|
||||
"""
|
||||
|
||||
from collections import Counter, OrderedDict
|
||||
from collections import OrderedDict
|
||||
import functools
|
||||
import numpy as np
|
||||
|
||||
|
@ -64,8 +64,11 @@ def plot_histogram(
|
|||
dict containing the values to represent (ex {'001': 130})
|
||||
figsize (tuple): Figure size in inches.
|
||||
color (list or str): String or list of strings for histogram bar colors.
|
||||
number_to_keep (int): The number of terms to plot and rest
|
||||
is made into a single bar called 'rest'.
|
||||
number_to_keep (int): The number of terms to plot per dataset. The rest is made into a
|
||||
single bar called 'rest'. If multiple datasets are given, the ``number_to_keep``
|
||||
applies to each dataset individually, which may result in more bars than
|
||||
``number_to_keep + 1``. The ``number_to_keep`` applies to the total values, rather than
|
||||
the x-axis sort.
|
||||
sort (string): Could be `'asc'`, `'desc'`, `'hamming'`, `'value'`, or
|
||||
`'value_desc'`. If set to `'value'` or `'value_desc'` the x axis
|
||||
will be sorted by the maximum probability for each bitstring.
|
||||
|
@ -148,7 +151,7 @@ def plot_histogram(
|
|||
if sort in DIST_MEAS:
|
||||
dist = []
|
||||
for item in labels:
|
||||
dist.append(DIST_MEAS[sort](item, target_string))
|
||||
dist.append(DIST_MEAS[sort](item, target_string) if item != "rest" else 0)
|
||||
|
||||
labels = [list(x) for x in zip(*sorted(zip(dist, labels), key=lambda pair: pair[0]))][1]
|
||||
elif "value" in sort:
|
||||
|
@ -241,6 +244,26 @@ def plot_histogram(
|
|||
return fig.savefig(filename)
|
||||
|
||||
|
||||
def _keep_largest_items(execution, number_to_keep):
|
||||
"""Keep only the largest values in a dictionary, and sum the rest into a new key 'rest'."""
|
||||
sorted_counts = sorted(execution.items(), key=lambda p: p[1])
|
||||
rest = sum(count for key, count in sorted_counts[:-number_to_keep])
|
||||
return dict(sorted_counts[-number_to_keep:], rest=rest)
|
||||
|
||||
|
||||
def _unify_labels(data):
|
||||
"""Make all dictionaries in data have the same set of keys, using 0 for missing values."""
|
||||
data = tuple(data)
|
||||
all_labels = set().union(*(execution.keys() for execution in data))
|
||||
base = {label: 0 for label in all_labels}
|
||||
out = []
|
||||
for execution in data:
|
||||
new_execution = base.copy()
|
||||
new_execution.update(execution)
|
||||
out.append(new_execution)
|
||||
return out
|
||||
|
||||
|
||||
def _plot_histogram_data(data, labels, number_to_keep):
|
||||
"""Generate the data needed for plotting counts.
|
||||
|
||||
|
@ -259,22 +282,21 @@ def _plot_histogram_data(data, labels, number_to_keep):
|
|||
experiment.
|
||||
"""
|
||||
labels_dict = OrderedDict()
|
||||
|
||||
all_pvalues = []
|
||||
all_inds = []
|
||||
|
||||
if isinstance(data, dict):
|
||||
data = [data]
|
||||
if number_to_keep is not None:
|
||||
data = _unify_labels(_keep_largest_items(execution, number_to_keep) for execution in data)
|
||||
|
||||
for execution in data:
|
||||
if number_to_keep is not None:
|
||||
data_temp = dict(Counter(execution).most_common(number_to_keep))
|
||||
data_temp["rest"] = sum(execution.values()) - sum(data_temp.values())
|
||||
execution = data_temp
|
||||
values = []
|
||||
for key in labels:
|
||||
if key not in execution:
|
||||
if number_to_keep is None:
|
||||
labels_dict[key] = 1
|
||||
values.append(0)
|
||||
else:
|
||||
values.append(-1)
|
||||
else:
|
||||
labels_dict[key] = 1
|
||||
values.append(execution[key])
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
---
|
||||
fixes:
|
||||
- |
|
||||
Fixed a bug in :func:`~qiskit.visualization.plot_histogram` when the
|
||||
``number_to_keep`` argument was smaller that the number of keys. The
|
||||
following code will not throw errors and will be properly aligned::
|
||||
|
||||
from qiskit.visualization import plot_histogram
|
||||
data = {'00': 3, '01': 5, '11': 8, '10': 11}
|
||||
plot_histogram(data, number_to_keep=2)
|
Binary file not shown.
After Width: | Height: | Size: 22 KiB |
Binary file not shown.
After Width: | Height: | Size: 17 KiB |
|
@ -171,6 +171,19 @@ class TestGraphMatplotlibDrawer(QiskitTestCase):
|
|||
|
||||
self.graph_count_drawer(counts, filename="histogram.png")
|
||||
|
||||
def test_plot_histogram_with_rest(self):
|
||||
"""test plot_histogram with 2 datasets and number_to_keep"""
|
||||
data = [{"00": 3, "01": 5, "10": 6, "11": 12}]
|
||||
self.graph_count_drawer(data, number_to_keep=2, filename="histogram_with_rest.png")
|
||||
|
||||
def test_plot_histogram_2_sets_with_rest(self):
|
||||
"""test plot_histogram with 2 datasets and number_to_keep"""
|
||||
data = [
|
||||
{"00": 3, "01": 5, "10": 6, "11": 12},
|
||||
{"00": 5, "01": 7, "10": 6, "11": 12},
|
||||
]
|
||||
self.graph_count_drawer(data, number_to_keep=2, filename="histogram_2_sets_with_rest.png")
|
||||
|
||||
def test_plot_histogram_color(self):
|
||||
"""Test histogram with single color"""
|
||||
|
||||
|
|
|
@ -13,15 +13,21 @@
|
|||
"""Tests for plot_histogram."""
|
||||
|
||||
import unittest
|
||||
from io import BytesIO
|
||||
from collections import Counter
|
||||
|
||||
import matplotlib as mpl
|
||||
from PIL import Image
|
||||
|
||||
from qiskit.test import QiskitTestCase
|
||||
from qiskit.tools.visualization import plot_histogram
|
||||
from qiskit.utils import optionals
|
||||
from .visualization import QiskitVisualizationTestCase
|
||||
|
||||
|
||||
class TestPlotHistogram(QiskitTestCase):
|
||||
class TestPlotHistogram(QiskitVisualizationTestCase):
|
||||
"""Qiskit plot_histogram tests."""
|
||||
|
||||
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
|
||||
def test_different_counts_lengths(self):
|
||||
"""Test plotting two different length dists works"""
|
||||
exact_dist = {
|
||||
|
@ -107,6 +113,107 @@ class TestPlotHistogram(QiskitTestCase):
|
|||
fig = plot_histogram([raw_dist, exact_dist])
|
||||
self.assertIsInstance(fig, mpl.figure.Figure)
|
||||
|
||||
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
|
||||
def test_with_number_to_keep(self):
|
||||
"""Test plotting using number_to_keep"""
|
||||
dist = {"00": 3, "01": 5, "11": 8, "10": 11}
|
||||
fig = plot_histogram(dist, number_to_keep=2)
|
||||
self.assertIsInstance(fig, mpl.figure.Figure)
|
||||
|
||||
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
|
||||
def test_with_number_to_keep_multiple_executions(self):
|
||||
"""Test plotting using number_to_keep with multiple executions"""
|
||||
dist = [{"00": 3, "01": 5, "11": 8, "10": 11}, {"00": 3, "01": 7, "10": 11}]
|
||||
fig = plot_histogram(dist, number_to_keep=2)
|
||||
self.assertIsInstance(fig, mpl.figure.Figure)
|
||||
|
||||
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
|
||||
def test_with_number_to_keep_multiple_executions_correct_image(self):
|
||||
"""Test plotting using number_to_keep with multiple executions"""
|
||||
data_noisy = {
|
||||
"00000": 0.22,
|
||||
"00001": 0.003,
|
||||
"00010": 0.005,
|
||||
"00011": 0.0,
|
||||
"00100": 0.004,
|
||||
"00101": 0.001,
|
||||
"00110": 0.004,
|
||||
"00111": 0.001,
|
||||
"01000": 0.005,
|
||||
"01001": 0.0,
|
||||
"01010": 0.002,
|
||||
"01011": 0.0,
|
||||
"01100": 0.225,
|
||||
"01101": 0.001,
|
||||
"01110": 0.003,
|
||||
"01111": 0.003,
|
||||
"10000": 0.012,
|
||||
"10001": 0.002,
|
||||
"10010": 0.001,
|
||||
"10011": 0.001,
|
||||
"10100": 0.247,
|
||||
"10101": 0.004,
|
||||
"10110": 0.003,
|
||||
"10111": 0.001,
|
||||
"11000": 0.225,
|
||||
"11001": 0.005,
|
||||
"11010": 0.002,
|
||||
"11011": 0.0,
|
||||
"11100": 0.015,
|
||||
"11101": 0.004,
|
||||
"11110": 0.001,
|
||||
"11111": 0.0,
|
||||
}
|
||||
data_ideal = {
|
||||
"00000": 0.25,
|
||||
"00001": 0,
|
||||
"00010": 0,
|
||||
"00011": 0,
|
||||
"00100": 0,
|
||||
"00101": 0,
|
||||
"00110": 0,
|
||||
"00111": 0.0,
|
||||
"01000": 0.0,
|
||||
"01001": 0,
|
||||
"01010": 0.0,
|
||||
"01011": 0.0,
|
||||
"01100": 0.25,
|
||||
"01101": 0,
|
||||
"01110": 0,
|
||||
"01111": 0,
|
||||
"10000": 0,
|
||||
"10001": 0,
|
||||
"10010": 0.0,
|
||||
"10011": 0.0,
|
||||
"10100": 0.25,
|
||||
"10101": 0,
|
||||
"10110": 0,
|
||||
"10111": 0,
|
||||
"11000": 0.25,
|
||||
"11001": 0,
|
||||
"11010": 0,
|
||||
"11011": 0,
|
||||
"11100": 0.0,
|
||||
"11101": 0,
|
||||
"11110": 0,
|
||||
"11111": 0.0,
|
||||
}
|
||||
data_ref_noisy = dict(Counter(data_noisy).most_common(5))
|
||||
data_ref_noisy["rest"] = sum(data_noisy.values()) - sum(data_ref_noisy.values())
|
||||
data_ref_ideal = dict(Counter(data_ideal).most_common(4)) # do not add 0 values
|
||||
data_ref_ideal["rest"] = 0
|
||||
figure_ref = plot_histogram([data_ref_ideal, data_ref_noisy])
|
||||
figure_truncated = plot_histogram([data_ideal, data_noisy], number_to_keep=5)
|
||||
with BytesIO() as img_buffer_ref:
|
||||
figure_ref.savefig(img_buffer_ref, format="png")
|
||||
img_buffer_ref.seek(0)
|
||||
with BytesIO() as img_buffer:
|
||||
figure_truncated.savefig(img_buffer, format="png")
|
||||
img_buffer.seek(0)
|
||||
self.assertImagesAreEqual(Image.open(img_buffer_ref), Image.open(img_buffer), 0.2)
|
||||
mpl.pyplot.close(figure_ref)
|
||||
mpl.pyplot.close(figure_truncated)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
|
Loading…
Reference in New Issue