Fixed error in plot_histogram when number_to_keep is smaller that the number of keys (#7481)

* plot_histogram fails when given number_to_keep parameter #7461

* Fix style

* fix review comments

* remove temporary the release notes

* add release notes

* Reformat release notes

* Temp - remove code from release note

* Add code example in release note.

* Add code example in release note.

* Add code example in release note.

* add test for multiple executions display

* change dictionary from OderedDict to defacultdict

* small fixes

* Correct use of optional tests

* Reword release note

* Improve documentation of `number_to_keep`

* Fix crash on distance measures

* Revert if.

* refactor code & add test

Co-authored-by: Jake Lishman <jake.lishman@ibm.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Iulia Zidaru 2022-06-07 16:09:06 +03:00 committed by GitHub
parent a84380d431
commit 1658bb4626
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 165 additions and 13 deletions

View File

@ -14,7 +14,7 @@
Visualization functions for measurement counts. Visualization functions for measurement counts.
""" """
from collections import Counter, OrderedDict from collections import OrderedDict
import functools import functools
import numpy as np import numpy as np
@ -64,8 +64,11 @@ def plot_histogram(
dict containing the values to represent (ex {'001': 130}) dict containing the values to represent (ex {'001': 130})
figsize (tuple): Figure size in inches. figsize (tuple): Figure size in inches.
color (list or str): String or list of strings for histogram bar colors. color (list or str): String or list of strings for histogram bar colors.
number_to_keep (int): The number of terms to plot and rest number_to_keep (int): The number of terms to plot per dataset. The rest is made into a
is made into a single bar called 'rest'. single bar called 'rest'. If multiple datasets are given, the ``number_to_keep``
applies to each dataset individually, which may result in more bars than
``number_to_keep + 1``. The ``number_to_keep`` applies to the total values, rather than
the x-axis sort.
sort (string): Could be `'asc'`, `'desc'`, `'hamming'`, `'value'`, or sort (string): Could be `'asc'`, `'desc'`, `'hamming'`, `'value'`, or
`'value_desc'`. If set to `'value'` or `'value_desc'` the x axis `'value_desc'`. If set to `'value'` or `'value_desc'` the x axis
will be sorted by the maximum probability for each bitstring. will be sorted by the maximum probability for each bitstring.
@ -148,7 +151,7 @@ def plot_histogram(
if sort in DIST_MEAS: if sort in DIST_MEAS:
dist = [] dist = []
for item in labels: for item in labels:
dist.append(DIST_MEAS[sort](item, target_string)) dist.append(DIST_MEAS[sort](item, target_string) if item != "rest" else 0)
labels = [list(x) for x in zip(*sorted(zip(dist, labels), key=lambda pair: pair[0]))][1] labels = [list(x) for x in zip(*sorted(zip(dist, labels), key=lambda pair: pair[0]))][1]
elif "value" in sort: elif "value" in sort:
@ -241,6 +244,26 @@ def plot_histogram(
return fig.savefig(filename) return fig.savefig(filename)
def _keep_largest_items(execution, number_to_keep):
"""Keep only the largest values in a dictionary, and sum the rest into a new key 'rest'."""
sorted_counts = sorted(execution.items(), key=lambda p: p[1])
rest = sum(count for key, count in sorted_counts[:-number_to_keep])
return dict(sorted_counts[-number_to_keep:], rest=rest)
def _unify_labels(data):
"""Make all dictionaries in data have the same set of keys, using 0 for missing values."""
data = tuple(data)
all_labels = set().union(*(execution.keys() for execution in data))
base = {label: 0 for label in all_labels}
out = []
for execution in data:
new_execution = base.copy()
new_execution.update(execution)
out.append(new_execution)
return out
def _plot_histogram_data(data, labels, number_to_keep): def _plot_histogram_data(data, labels, number_to_keep):
"""Generate the data needed for plotting counts. """Generate the data needed for plotting counts.
@ -259,22 +282,21 @@ def _plot_histogram_data(data, labels, number_to_keep):
experiment. experiment.
""" """
labels_dict = OrderedDict() labels_dict = OrderedDict()
all_pvalues = [] all_pvalues = []
all_inds = [] all_inds = []
if isinstance(data, dict):
data = [data]
if number_to_keep is not None:
data = _unify_labels(_keep_largest_items(execution, number_to_keep) for execution in data)
for execution in data: for execution in data:
if number_to_keep is not None:
data_temp = dict(Counter(execution).most_common(number_to_keep))
data_temp["rest"] = sum(execution.values()) - sum(data_temp.values())
execution = data_temp
values = [] values = []
for key in labels: for key in labels:
if key not in execution: if key not in execution:
if number_to_keep is None: if number_to_keep is None:
labels_dict[key] = 1 labels_dict[key] = 1
values.append(0) values.append(0)
else:
values.append(-1)
else: else:
labels_dict[key] = 1 labels_dict[key] = 1
values.append(execution[key]) values.append(execution[key])

View File

@ -0,0 +1,10 @@
---
fixes:
- |
Fixed a bug in :func:`~qiskit.visualization.plot_histogram` when the
``number_to_keep`` argument was smaller that the number of keys. The
following code will not throw errors and will be properly aligned::
from qiskit.visualization import plot_histogram
data = {'00': 3, '01': 5, '11': 8, '10': 11}
plot_histogram(data, number_to_keep=2)

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

View File

@ -171,6 +171,19 @@ class TestGraphMatplotlibDrawer(QiskitTestCase):
self.graph_count_drawer(counts, filename="histogram.png") self.graph_count_drawer(counts, filename="histogram.png")
def test_plot_histogram_with_rest(self):
"""test plot_histogram with 2 datasets and number_to_keep"""
data = [{"00": 3, "01": 5, "10": 6, "11": 12}]
self.graph_count_drawer(data, number_to_keep=2, filename="histogram_with_rest.png")
def test_plot_histogram_2_sets_with_rest(self):
"""test plot_histogram with 2 datasets and number_to_keep"""
data = [
{"00": 3, "01": 5, "10": 6, "11": 12},
{"00": 5, "01": 7, "10": 6, "11": 12},
]
self.graph_count_drawer(data, number_to_keep=2, filename="histogram_2_sets_with_rest.png")
def test_plot_histogram_color(self): def test_plot_histogram_color(self):
"""Test histogram with single color""" """Test histogram with single color"""

View File

@ -13,15 +13,21 @@
"""Tests for plot_histogram.""" """Tests for plot_histogram."""
import unittest import unittest
from io import BytesIO
from collections import Counter
import matplotlib as mpl import matplotlib as mpl
from PIL import Image
from qiskit.test import QiskitTestCase
from qiskit.tools.visualization import plot_histogram from qiskit.tools.visualization import plot_histogram
from qiskit.utils import optionals
from .visualization import QiskitVisualizationTestCase
class TestPlotHistogram(QiskitTestCase): class TestPlotHistogram(QiskitVisualizationTestCase):
"""Qiskit plot_histogram tests.""" """Qiskit plot_histogram tests."""
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_different_counts_lengths(self): def test_different_counts_lengths(self):
"""Test plotting two different length dists works""" """Test plotting two different length dists works"""
exact_dist = { exact_dist = {
@ -107,6 +113,107 @@ class TestPlotHistogram(QiskitTestCase):
fig = plot_histogram([raw_dist, exact_dist]) fig = plot_histogram([raw_dist, exact_dist])
self.assertIsInstance(fig, mpl.figure.Figure) self.assertIsInstance(fig, mpl.figure.Figure)
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_with_number_to_keep(self):
"""Test plotting using number_to_keep"""
dist = {"00": 3, "01": 5, "11": 8, "10": 11}
fig = plot_histogram(dist, number_to_keep=2)
self.assertIsInstance(fig, mpl.figure.Figure)
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_with_number_to_keep_multiple_executions(self):
"""Test plotting using number_to_keep with multiple executions"""
dist = [{"00": 3, "01": 5, "11": 8, "10": 11}, {"00": 3, "01": 7, "10": 11}]
fig = plot_histogram(dist, number_to_keep=2)
self.assertIsInstance(fig, mpl.figure.Figure)
@unittest.skipUnless(optionals.HAS_MATPLOTLIB, "matplotlib not available.")
def test_with_number_to_keep_multiple_executions_correct_image(self):
"""Test plotting using number_to_keep with multiple executions"""
data_noisy = {
"00000": 0.22,
"00001": 0.003,
"00010": 0.005,
"00011": 0.0,
"00100": 0.004,
"00101": 0.001,
"00110": 0.004,
"00111": 0.001,
"01000": 0.005,
"01001": 0.0,
"01010": 0.002,
"01011": 0.0,
"01100": 0.225,
"01101": 0.001,
"01110": 0.003,
"01111": 0.003,
"10000": 0.012,
"10001": 0.002,
"10010": 0.001,
"10011": 0.001,
"10100": 0.247,
"10101": 0.004,
"10110": 0.003,
"10111": 0.001,
"11000": 0.225,
"11001": 0.005,
"11010": 0.002,
"11011": 0.0,
"11100": 0.015,
"11101": 0.004,
"11110": 0.001,
"11111": 0.0,
}
data_ideal = {
"00000": 0.25,
"00001": 0,
"00010": 0,
"00011": 0,
"00100": 0,
"00101": 0,
"00110": 0,
"00111": 0.0,
"01000": 0.0,
"01001": 0,
"01010": 0.0,
"01011": 0.0,
"01100": 0.25,
"01101": 0,
"01110": 0,
"01111": 0,
"10000": 0,
"10001": 0,
"10010": 0.0,
"10011": 0.0,
"10100": 0.25,
"10101": 0,
"10110": 0,
"10111": 0,
"11000": 0.25,
"11001": 0,
"11010": 0,
"11011": 0,
"11100": 0.0,
"11101": 0,
"11110": 0,
"11111": 0.0,
}
data_ref_noisy = dict(Counter(data_noisy).most_common(5))
data_ref_noisy["rest"] = sum(data_noisy.values()) - sum(data_ref_noisy.values())
data_ref_ideal = dict(Counter(data_ideal).most_common(4)) # do not add 0 values
data_ref_ideal["rest"] = 0
figure_ref = plot_histogram([data_ref_ideal, data_ref_noisy])
figure_truncated = plot_histogram([data_ideal, data_noisy], number_to_keep=5)
with BytesIO() as img_buffer_ref:
figure_ref.savefig(img_buffer_ref, format="png")
img_buffer_ref.seek(0)
with BytesIO() as img_buffer:
figure_truncated.savefig(img_buffer, format="png")
img_buffer.seek(0)
self.assertImagesAreEqual(Image.open(img_buffer_ref), Image.open(img_buffer), 0.2)
mpl.pyplot.close(figure_ref)
mpl.pyplot.close(figure_truncated)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main(verbosity=2) unittest.main(verbosity=2)