qiskit/test/python/quantum_info/states/test_random.py

106 lines
4.0 KiB
Python

# This code is part of Qiskit.
#
# (C) Copyright IBM 2017, 2020.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Test random operator functions."""
import unittest
from test import combine
from ddt import ddt
import numpy as np
from qiskit.quantum_info import Statevector, DensityMatrix
from qiskit.quantum_info.random import random_statevector
from qiskit.quantum_info.random import random_density_matrix
from test import QiskitTestCase # pylint: disable=wrong-import-order
@ddt
class TestRandomStatevector(QiskitTestCase):
"""Testing random_unitary function."""
@combine(dims=[(2,), (3,), (2, 2), (2, 3)])
def test_tuple_dims(self, dims):
"""Test random_statevector is valid with dims {dims}."""
value = random_statevector(dims)
self.assertIsInstance(value, Statevector)
self.assertTrue(value.is_valid())
self.assertEqual(value.dims(), dims)
@combine(dim=[2, 3, 4, 5])
def test_int_dims(self, dim):
"""Test random_statevector is valid with dims {dim}."""
value = random_statevector(dim)
self.assertIsInstance(value, Statevector)
self.assertTrue(value.is_valid())
self.assertEqual(np.prod(value.dims()), dim)
def test_fixed_seed(self):
"""Test fixing seed fixes output"""
seed = 1532
value1 = random_statevector(4, seed=seed)
value2 = random_statevector(4, seed=seed)
self.assertEqual(value1, value2)
def test_not_global_seed(self):
"""Test fixing seed is locally scoped."""
seed = 314159
test_cases = 100
random_statevector(2, seed=seed)
rng_before = np.random.randint(1000, size=test_cases)
random_statevector(2, seed=seed)
rng_after = np.random.randint(1000, size=test_cases)
self.assertFalse(np.all(rng_before == rng_after))
@ddt
class TestRandomDensityMatrix(QiskitTestCase):
"""Testing random_density_matrix function."""
@combine(dims=[(2,), (3,), (2, 2), (2, 3)], method=["Hilbert-Schmidt", "Bures"])
def test_tuple_dims(self, dims, method):
"""Test random_density_matrix {method} method is valid with dims {dims}."""
value = random_density_matrix(dims, method=method)
self.assertIsInstance(value, DensityMatrix)
self.assertTrue(value.is_valid())
self.assertEqual(value.dims(), dims)
@combine(dim=[2, 3, 4, 5], method=["Hilbert-Schmidt", "Bures"])
def test_int_dims(self, dim, method):
"""Test random_density_matrix {method} method is valid with dims {dim}."""
value = random_density_matrix(dim, method=method)
self.assertIsInstance(value, DensityMatrix)
self.assertTrue(value.is_valid())
self.assertEqual(np.prod(value.dims()), dim)
@combine(method=["Hilbert-Schmidt", "Bures"])
def test_fixed_seed(self, method):
"""Test fixing seed fixes output ({method} method)"""
seed = 1532
value1 = random_density_matrix(4, method=method, seed=seed)
value2 = random_density_matrix(4, method=method, seed=seed)
self.assertEqual(value1, value2)
@combine(method=["Hilbert-Schmidt", "Bures"])
def test_not_global_seed(self, method):
"""Test fixing seed is locally scoped ({method} method)."""
seed = 314159
test_cases = 100
random_density_matrix(2, method=method, seed=seed)
rng_before = np.random.randint(1000, size=test_cases)
random_density_matrix(2, method=method, seed=seed)
rng_after = np.random.randint(1000, size=test_cases)
self.assertFalse(np.all(rng_before == rng_after))
if __name__ == "__main__":
unittest.main()