Fixed error in plot_histogram when number_to_keep is smaller that the number of keys (#7481)

* plot_histogram fails when given number_to_keep parameter #7461

* Fix style

* fix review comments

* remove temporary the release notes

* add release notes

* Reformat release notes

* Temp - remove code from release note

* Add code example in release note.

* Add code example in release note.

* Add code example in release note.

* add test for multiple executions display

* change dictionary from OderedDict to defacultdict

* small fixes

* Correct use of optional tests

* Reword release note

* Improve documentation of `number_to_keep`

* Fix crash on distance measures

* Revert if.

* refactor code & add test

Co-authored-by: Jake Lishman <jake.lishman@ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Iulia Zidaru 2022-06-07 16:09:06 +03:00 committed by GitHub
parent a84380d431
commit 1658bb4626
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 165 additions and 13 deletions

View File

@ -14,7 +14,7 @@
Visualization functions for measurement counts.
"""
from collections import Counter, OrderedDict
from collections import OrderedDict
import functools
import numpy as np
@ -64,8 +64,11 @@ def plot_histogram(
dict containing the values to represent (ex {'001': 130})
figsize (tuple): Figure size in inches.
color (list or str): String or list of strings for histogram bar colors.
number_to_keep (int): The number of terms to plot and rest
is made into a single bar called 'rest'.
number_to_keep (int): The number of terms to plot per dataset. The rest is made into a
single bar called 'rest'. If multiple datasets are given, the ``number_to_keep``
applies to each dataset individually, which may result in more bars than
``number_to_keep + 1``. The ``number_to_keep`` applies to the total values, rather than
the x-axis sort.
sort (string): Could be `'asc'`, `'desc'`, `'hamming'`, `'value'`, or
`'value_desc'`. If set to `'value'` or `'value_desc'` the x axis
will be sorted by the maximum probability for each bitstring.
@ -148,7 +151,7 @@ def plot_histogram(
if sort in DIST_MEAS:
dist = []
for item in labels:
dist.append(DIST_MEAS[sort](item, target_string))
dist.append(DIST_MEAS[sort](item, target_string) if item != "rest" else 0)
labels = [list(x) for x in zip(*sorted(zip(dist, labels), key=lambda pair: pair[0]))][1]
elif "value" in sort:
@ -241,6 +244,26 @@ def plot_histogram(
return fig.savefig(filename)
def _keep_largest_items(execution, number_to_keep):
"""Keep only the largest values in a dictionary, and sum the rest into a new key 'rest'."""
sorted_counts = sorted(execution.items(), key=lambda p: p[1])
rest = sum(count for key, count in sorted_counts[:-number_to_keep])
return dict(sorted_counts[-number_to_keep:], rest=rest)
def _unify_labels(data):
"""Make all dictionaries in data have the same set of keys, using 0 for missing values."""
data = tuple(data)
all_labels = set().union(*(execution.keys() for execution in data))
base = {label: 0 for label in all_labels}
out = []
for execution in data:
new_execution = base.copy()
new_execution.update(execution)
out.append(new_execution)
return out
def _plot_histogram_data(data, labels, number_to_keep):
"""Generate the data needed for plotting counts.
@ -259,22 +282,21 @@ def _plot_histogram_data(data, labels, number_to_keep):
experiment.
"""
labels_dict = OrderedDict()
all_pvalues = []
all_inds = []
if isinstance(data, dict):
data = [data]
if number_to_keep is not None:
data = _unify_labels(_keep_largest_items(execution, number_to_keep) for execution in data)
for execution in data:
if number_to_keep is not None:
data_temp = dict(Counter(execution).most_common(number_to_keep))
data_temp["rest"] = sum(execution.values()) - sum(data_temp.values())
execution = data_temp
values = []
for key in labels:
if key not in execution:
if number_to_keep is None:
labels_dict[key] = 1
values.append(0)
else:
values.append(-1)
else:
labels_dict[key] = 1
values.append(execution[key])

View File

@ -0,0 +1,10 @@
---
fixes:
- |
Fixed a bug in :func:`~qiskit.visualization.plot_histogram` when the
``number_to_keep`` argument was smaller that the number of keys. The
following code will not throw errors and will be properly aligned::
from qiskit.visualization import plot_histogram
data = {'00': 3, '01': 5, '11': 8, '10': 11}
plot_histogram(data, number_to_keep=2)

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

View File

@ -171,6 +171,19 @@ class TestGraphMatplotlibDrawer(QiskitTestCase):
self.graph_count_drawer(counts, filename="histogram.png")
def test_plot_histogram_with_rest(self):
"""test plot_histogram with 2 datasets and number_to_keep"""
data = [{"00": 3, "01": 5, "10": 6, "11": 12}]
self.graph_count_drawer(data, number_to_keep=2, filename="histogram_with_rest.png")
def test_plot_histogram_2_sets_with_rest(self):
"""test plot_histogram with 2 datasets and number_to_keep"""
data = [
{"00": 3, "01": 5, "10": 6, "11": 12},
{"00": 5, "01": 7, "10": 6, "11": 12},
]
self.graph_count_drawer(data, number_to_keep=2, filename="histogram_2_sets_with_rest.png")
def test_plot_histogram_color(self):
"""Test histogram with single color"""

View File

@ -13,15 +13,21 @@
"""Tests for plot_histogram."""
import unittest
from io import BytesIO
from collections import Counter
import matplotlib as mpl
from PIL import Image
from qiskit.test import QiskitTestCase
from qiskit.tools.visualization import plot_histogram
from qiskit.utils import optionals
from .visualization import QiskitVisualizationTestCase
class TestPlotHistogram(QiskitTestCase):
class TestPlotHistogram(QiskitVisualizationTestCase):
"""Qiskit plot_histogram tests."""
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_different_counts_lengths(self):
"""Test plotting two different length dists works"""
exact_dist = {
@ -107,6 +113,107 @@ class TestPlotHistogram(QiskitTestCase):
fig = plot_histogram([raw_dist, exact_dist])
self.assertIsInstance(fig, mpl.figure.Figure)
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_with_number_to_keep(self):
"""Test plotting using number_to_keep"""
dist = {"00": 3, "01": 5, "11": 8, "10": 11}
fig = plot_histogram(dist, number_to_keep=2)
self.assertIsInstance(fig, mpl.figure.Figure)
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_with_number_to_keep_multiple_executions(self):
"""Test plotting using number_to_keep with multiple executions"""
dist = [{"00": 3, "01": 5, "11": 8, "10": 11}, {"00": 3, "01": 7, "10": 11}]
fig = plot_histogram(dist, number_to_keep=2)
self.assertIsInstance(fig, mpl.figure.Figure)
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_with_number_to_keep_multiple_executions_correct_image(self):
"""Test plotting using number_to_keep with multiple executions"""
data_noisy = {
"00000": 0.22,
"00001": 0.003,
"00010": 0.005,
"00011": 0.0,
"00100": 0.004,
"00101": 0.001,
"00110": 0.004,
"00111": 0.001,
"01000": 0.005,
"01001": 0.0,
"01010": 0.002,
"01011": 0.0,
"01100": 0.225,
"01101": 0.001,
"01110": 0.003,
"01111": 0.003,
"10000": 0.012,
"10001": 0.002,
"10010": 0.001,
"10011": 0.001,
"10100": 0.247,
"10101": 0.004,
"10110": 0.003,
"10111": 0.001,
"11000": 0.225,
"11001": 0.005,
"11010": 0.002,
"11011": 0.0,
"11100": 0.015,
"11101": 0.004,
"11110": 0.001,
"11111": 0.0,
}
data_ideal = {
"00000": 0.25,
"00001": 0,
"00010": 0,
"00011": 0,
"00100": 0,
"00101": 0,
"00110": 0,
"00111": 0.0,
"01000": 0.0,
"01001": 0,
"01010": 0.0,
"01011": 0.0,
"01100": 0.25,
"01101": 0,
"01110": 0,
"01111": 0,
"10000": 0,
"10001": 0,
"10010": 0.0,
"10011": 0.0,
"10100": 0.25,
"10101": 0,
"10110": 0,
"10111": 0,
"11000": 0.25,
"11001": 0,
"11010": 0,
"11011": 0,
"11100": 0.0,
"11101": 0,
"11110": 0,
"11111": 0.0,
}
data_ref_noisy = dict(Counter(data_noisy).most_common(5))
data_ref_noisy["rest"] = sum(data_noisy.values()) - sum(data_ref_noisy.values())
data_ref_ideal = dict(Counter(data_ideal).most_common(4)) # do not add 0 values
data_ref_ideal["rest"] = 0
figure_ref = plot_histogram([data_ref_ideal, data_ref_noisy])
figure_truncated = plot_histogram([data_ideal, data_noisy], number_to_keep=5)
with BytesIO() as img_buffer_ref:
figure_ref.savefig(img_buffer_ref, format="png")
img_buffer_ref.seek(0)
with BytesIO() as img_buffer:
figure_truncated.savefig(img_buffer, format="png")
img_buffer.seek(0)
self.assertImagesAreEqual(Image.open(img_buffer_ref), Image.open(img_buffer), 0.2)
mpl.pyplot.close(figure_ref)
mpl.pyplot.close(figure_truncated)
if __name__ == "__main__":
unittest.main(verbosity=2)