mirror of https://github.com/Qiskit/qiskit.git
106 lines
4.0 KiB
Python
106 lines
4.0 KiB
Python
# This code is part of Qiskit.
|
|
#
|
|
# (C) Copyright IBM 2017, 2020.
|
|
#
|
|
# This code is licensed under the Apache License, Version 2.0. You may
|
|
# obtain a copy of this license in the LICENSE.txt file in the root directory
|
|
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
|
|
#
|
|
# Any modifications or derivative works of this code must retain this
|
|
# copyright notice, and modified files need to carry a notice indicating
|
|
# that they have been altered from the originals.
|
|
|
|
"""Test random operator functions."""
|
|
|
|
import unittest
|
|
from test import combine
|
|
from ddt import ddt
|
|
import numpy as np
|
|
|
|
from qiskit.quantum_info import Statevector, DensityMatrix
|
|
from qiskit.quantum_info.random import random_statevector
|
|
from qiskit.quantum_info.random import random_density_matrix
|
|
from test import QiskitTestCase # pylint: disable=wrong-import-order
|
|
|
|
|
|
@ddt
|
|
class TestRandomStatevector(QiskitTestCase):
|
|
"""Testing random_unitary function."""
|
|
|
|
@combine(dims=[(2,), (3,), (2, 2), (2, 3)])
|
|
def test_tuple_dims(self, dims):
|
|
"""Test random_statevector is valid with dims {dims}."""
|
|
value = random_statevector(dims)
|
|
self.assertIsInstance(value, Statevector)
|
|
self.assertTrue(value.is_valid())
|
|
self.assertEqual(value.dims(), dims)
|
|
|
|
@combine(dim=[2, 3, 4, 5])
|
|
def test_int_dims(self, dim):
|
|
"""Test random_statevector is valid with dims {dim}."""
|
|
value = random_statevector(dim)
|
|
self.assertIsInstance(value, Statevector)
|
|
self.assertTrue(value.is_valid())
|
|
self.assertEqual(np.prod(value.dims()), dim)
|
|
|
|
def test_fixed_seed(self):
|
|
"""Test fixing seed fixes output"""
|
|
seed = 1532
|
|
value1 = random_statevector(4, seed=seed)
|
|
value2 = random_statevector(4, seed=seed)
|
|
self.assertEqual(value1, value2)
|
|
|
|
def test_not_global_seed(self):
|
|
"""Test fixing seed is locally scoped."""
|
|
seed = 314159
|
|
test_cases = 100
|
|
random_statevector(2, seed=seed)
|
|
rng_before = np.random.randint(1000, size=test_cases)
|
|
random_statevector(2, seed=seed)
|
|
rng_after = np.random.randint(1000, size=test_cases)
|
|
self.assertFalse(np.all(rng_before == rng_after))
|
|
|
|
|
|
@ddt
|
|
class TestRandomDensityMatrix(QiskitTestCase):
|
|
"""Testing random_density_matrix function."""
|
|
|
|
@combine(dims=[(2,), (3,), (2, 2), (2, 3)], method=["Hilbert-Schmidt", "Bures"])
|
|
def test_tuple_dims(self, dims, method):
|
|
"""Test random_density_matrix {method} method is valid with dims {dims}."""
|
|
value = random_density_matrix(dims, method=method)
|
|
self.assertIsInstance(value, DensityMatrix)
|
|
self.assertTrue(value.is_valid())
|
|
self.assertEqual(value.dims(), dims)
|
|
|
|
@combine(dim=[2, 3, 4, 5], method=["Hilbert-Schmidt", "Bures"])
|
|
def test_int_dims(self, dim, method):
|
|
"""Test random_density_matrix {method} method is valid with dims {dim}."""
|
|
value = random_density_matrix(dim, method=method)
|
|
self.assertIsInstance(value, DensityMatrix)
|
|
self.assertTrue(value.is_valid())
|
|
self.assertEqual(np.prod(value.dims()), dim)
|
|
|
|
@combine(method=["Hilbert-Schmidt", "Bures"])
|
|
def test_fixed_seed(self, method):
|
|
"""Test fixing seed fixes output ({method} method)"""
|
|
seed = 1532
|
|
value1 = random_density_matrix(4, method=method, seed=seed)
|
|
value2 = random_density_matrix(4, method=method, seed=seed)
|
|
self.assertEqual(value1, value2)
|
|
|
|
@combine(method=["Hilbert-Schmidt", "Bures"])
|
|
def test_not_global_seed(self, method):
|
|
"""Test fixing seed is locally scoped ({method} method)."""
|
|
seed = 314159
|
|
test_cases = 100
|
|
random_density_matrix(2, method=method, seed=seed)
|
|
rng_before = np.random.randint(1000, size=test_cases)
|
|
random_density_matrix(2, method=method, seed=seed)
|
|
rng_after = np.random.randint(1000, size=test_cases)
|
|
self.assertFalse(np.all(rng_before == rng_after))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|