mirror of https://github.com/Qiskit/qiskit.git
106 lines
3.5 KiB
Python
106 lines
3.5 KiB
Python
# This code is part of Qiskit.
|
|
#
|
|
# (C) Copyright IBM 2024.
|
|
#
|
|
# This code is licensed under the Apache License, Version 2.0. You may
|
|
# obtain a copy of this license in the LICENSE.txt file in the root directory
|
|
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
|
|
#
|
|
# Any modifications or derivative works of this code must retain this
|
|
# copyright notice, and modified files need to carry a notice indicating
|
|
# that they have been altered from the originals.
|
|
|
|
"""Test shape.py module"""
|
|
|
|
|
|
import numpy as np
|
|
|
|
from qiskit.primitives.containers.shape import Shaped, ShapedMixin, array_coerce, shape_tuple
|
|
from test import QiskitTestCase # pylint: disable=wrong-import-order
|
|
|
|
|
|
class DummyShaped(ShapedMixin):
|
|
"""Dummy ShapedMixin child for testing."""
|
|
|
|
def __init__(self, arr):
|
|
super().__init__()
|
|
self._shape = arr.shape
|
|
self._arr = arr
|
|
|
|
def __getitem__(self, arg):
|
|
return self._arr[arg]
|
|
|
|
|
|
class ShapedTestCase(QiskitTestCase):
|
|
"""Test the Shaped protocol class"""
|
|
|
|
def test_ndarray_is_shaped(self):
|
|
"""Test that ndarrays are shaped"""
|
|
self.assertTrue(isinstance(np.empty((1, 2, 3)), Shaped))
|
|
|
|
def test_mixin_is_shaped(self):
|
|
"""Test that ShapedMixin is shaped"""
|
|
self.assertTrue(isinstance(DummyShaped(np.empty((1, 2, 3))), Shaped))
|
|
|
|
|
|
class ShapedMixinTestCase(QiskitTestCase):
|
|
"""Test the ShapedMixin class"""
|
|
|
|
def test_shape(self):
|
|
"""Test the shape attribute."""
|
|
self.assertEqual(DummyShaped(np.empty((1, 2, 3))).shape, (1, 2, 3))
|
|
self.assertEqual(DummyShaped(np.empty(())).shape, ())
|
|
|
|
def test_ndim(self):
|
|
"""Test the ndim attribute."""
|
|
self.assertEqual(DummyShaped(np.empty(())).ndim, 0)
|
|
self.assertEqual(DummyShaped(np.empty((1, 2, 3))).ndim, 3)
|
|
|
|
def test_size(self):
|
|
"""Test the size attribute."""
|
|
self.assertEqual(DummyShaped(np.empty(())).size, 1)
|
|
self.assertEqual(DummyShaped(np.empty((0, 1))).size, 0)
|
|
self.assertEqual(DummyShaped(np.empty((1, 2, 3))).size, 6)
|
|
|
|
def test_getitem(self):
|
|
"""Missing docstring."""
|
|
arr = np.arange(100).reshape(2, 5, 10)
|
|
np.testing.assert_allclose(DummyShaped(arr)[:, 0, :2], arr[:, 0, :2])
|
|
|
|
|
|
class ArrayCoerceTestCase(QiskitTestCase):
|
|
"""Test array_coerce() function."""
|
|
|
|
def test_shaped(self):
|
|
"""Test that array_coerce() works with ShapedMixin objects."""
|
|
sh = DummyShaped(np.empty((1, 2, 3)))
|
|
self.assertIs(sh, array_coerce(sh))
|
|
|
|
def test_ndarray(self):
|
|
"""Test that array_coerce() works with ndarray objects."""
|
|
sh = np.arange(100).reshape(5, 2, 2, 5)
|
|
np.testing.assert_allclose(sh, array_coerce(sh))
|
|
|
|
|
|
class ShapeTupleTestCase(QiskitTestCase):
|
|
"""Test shape_tuple() function."""
|
|
|
|
def test_int(self):
|
|
"""Test shape_tuple() with int inputs."""
|
|
self.assertEqual(shape_tuple(), ())
|
|
self.assertEqual(shape_tuple(5), (5,))
|
|
self.assertEqual(shape_tuple(5, 10), (5, 10))
|
|
self.assertEqual(shape_tuple(1e2), (100,))
|
|
|
|
def test_nested(self):
|
|
"""Test shape_tuple() with nested inputs."""
|
|
self.assertEqual(shape_tuple(0, (), (1, (2, (3,)), (4, 5))), (0, 1, 2, 3, 4, 5))
|
|
|
|
def test_exceptions(self):
|
|
"""Test shape_tuple() raises correctly."""
|
|
with self.assertRaisesRegex(ValueError, "iterable or an integer"):
|
|
shape_tuple(None)
|
|
|
|
with self.assertRaisesRegex(ValueError, "iterable or an integer"):
|
|
shape_tuple(1.5)
|