mirror of https://github.com/Qiskit/qiskit.git
168 lines
6.5 KiB
Python
168 lines
6.5 KiB
Python
# This code is part of Qiskit.
|
|
#
|
|
# (C) Copyright IBM 2017, 2021.
|
|
#
|
|
# This code is licensed under the Apache License, Version 2.0. You may
|
|
# obtain a copy of this license in the LICENSE.txt file in the root directory
|
|
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
|
|
#
|
|
# Any modifications or derivative works of this code must retain this
|
|
# copyright notice, and modified files need to carry a notice indicating
|
|
# that they have been altered from the originals.
|
|
|
|
"""Test multiplier circuits."""
|
|
|
|
import unittest
|
|
import re
|
|
import numpy as np
|
|
from ddt import ddt, data, unpack
|
|
|
|
from qiskit import transpile
|
|
from qiskit.circuit import QuantumCircuit
|
|
from qiskit.quantum_info import Statevector
|
|
from qiskit.circuit.library import (
|
|
RGQFTMultiplier,
|
|
HRSCumulativeMultiplier,
|
|
CDKMRippleCarryAdder,
|
|
DraperQFTAdder,
|
|
VBERippleCarryAdder,
|
|
MultiplierGate,
|
|
)
|
|
from qiskit.transpiler.passes import HighLevelSynthesis, HLSConfig
|
|
from qiskit.synthesis.arithmetic import multiplier_qft_r17, multiplier_cumulative_h18
|
|
from test import QiskitTestCase # pylint: disable=wrong-import-order
|
|
|
|
|
|
@ddt
|
|
class TestMultiplier(QiskitTestCase):
|
|
"""Test the multiplier circuits."""
|
|
|
|
def assertMultiplicationIsCorrect(
|
|
self, num_state_qubits: int, num_result_qubits: int, multiplier: QuantumCircuit
|
|
):
|
|
"""Assert that multiplier correctly implements the product.
|
|
|
|
Args:
|
|
num_state_qubits: The number of bits in the numbers that are multiplied.
|
|
num_result_qubits: The number of qubits to limit the output to with modulo.
|
|
multiplier: The circuit performing the multiplication of two numbers with
|
|
``num_state_qubits`` bits.
|
|
"""
|
|
circuit = QuantumCircuit(*multiplier.qregs)
|
|
|
|
# create equal superposition
|
|
circuit.h(range(2 * num_state_qubits))
|
|
|
|
# apply multiplier circuit
|
|
circuit.compose(multiplier, inplace=True)
|
|
|
|
# obtain the statevector and the probabilities, we don't trace out the ancilla qubits
|
|
# as we verify that all ancilla qubits have been uncomputed to state 0 again
|
|
tqc = transpile(circuit, basis_gates=["h", "p", "cp", "rz", "cx", "ccx", "swap"])
|
|
statevector = Statevector(tqc)
|
|
probabilities = statevector.probabilities()
|
|
pad = "0" * circuit.num_ancillas # state of the ancillas
|
|
|
|
# compute the expected results
|
|
expectations = np.zeros_like(probabilities)
|
|
num_bits_product = num_state_qubits * 2
|
|
# iterate over all possible inputs
|
|
for x in range(2**num_state_qubits):
|
|
for y in range(2**num_state_qubits):
|
|
# compute the product
|
|
product = x * y % (2**num_result_qubits)
|
|
# compute correct index in statevector
|
|
bin_x = bin(x)[2:].zfill(num_state_qubits)
|
|
bin_y = bin(y)[2:].zfill(num_state_qubits)
|
|
bin_res = bin(product)[2:].zfill(num_bits_product)
|
|
bin_index = pad + bin_res + bin_y + bin_x
|
|
index = int(bin_index, 2)
|
|
expectations[index] += 1 / 2 ** (2 * num_state_qubits)
|
|
np.testing.assert_array_almost_equal(expectations, probabilities)
|
|
|
|
@data(
|
|
(3, RGQFTMultiplier),
|
|
(3, RGQFTMultiplier, 5),
|
|
(3, RGQFTMultiplier, 4),
|
|
(3, RGQFTMultiplier, 3),
|
|
(3, multiplier_qft_r17),
|
|
(3, multiplier_qft_r17, 5),
|
|
(3, multiplier_qft_r17, 4),
|
|
(3, multiplier_qft_r17, 3),
|
|
(3, HRSCumulativeMultiplier),
|
|
(3, HRSCumulativeMultiplier, 5),
|
|
(3, HRSCumulativeMultiplier, 4),
|
|
(3, HRSCumulativeMultiplier, 3),
|
|
(3, multiplier_cumulative_h18),
|
|
(3, multiplier_cumulative_h18, 5),
|
|
(3, multiplier_cumulative_h18, 4),
|
|
(3, multiplier_cumulative_h18, 3),
|
|
(3, HRSCumulativeMultiplier, None, CDKMRippleCarryAdder),
|
|
(3, HRSCumulativeMultiplier, None, DraperQFTAdder),
|
|
(3, HRSCumulativeMultiplier, None, VBERippleCarryAdder),
|
|
)
|
|
@unpack
|
|
def test_multiplication(self, num_state_qubits, multiplier, num_result_qubits=None, adder=None):
|
|
"""Test multiplication for all implemented multipliers."""
|
|
if num_result_qubits is None:
|
|
num_result_qubits = 2 * num_state_qubits
|
|
if adder is not None:
|
|
adder = adder(num_state_qubits, kind="half")
|
|
multiplier = multiplier(num_state_qubits, num_result_qubits, adder=adder)
|
|
else:
|
|
multiplier = multiplier(num_state_qubits, num_result_qubits)
|
|
|
|
self.assertMultiplicationIsCorrect(num_state_qubits, num_result_qubits, multiplier)
|
|
|
|
@data(
|
|
(RGQFTMultiplier, -1),
|
|
(HRSCumulativeMultiplier, -1),
|
|
(RGQFTMultiplier, 0, 0),
|
|
(HRSCumulativeMultiplier, 0, 0),
|
|
(RGQFTMultiplier, 0, 1),
|
|
(HRSCumulativeMultiplier, 0, 1),
|
|
(RGQFTMultiplier, 1, 0),
|
|
(HRSCumulativeMultiplier, 1, 0),
|
|
(RGQFTMultiplier, 3, 2),
|
|
(HRSCumulativeMultiplier, 3, 2),
|
|
(RGQFTMultiplier, 3, 7),
|
|
(HRSCumulativeMultiplier, 3, 7),
|
|
)
|
|
@unpack
|
|
def test_raises_on_wrong_num_bits(self, multiplier, num_state_qubits, num_result_qubits=None):
|
|
"""Test an error is raised for a bad number of state or result qubits."""
|
|
with self.assertRaises(ValueError):
|
|
_ = multiplier(num_state_qubits, num_result_qubits)
|
|
|
|
def test_modular_cumulative_multiplier_custom_adder(self):
|
|
"""Test an error is raised when a custom adder is used with modular cumulative multiplier."""
|
|
with self.assertRaises(NotImplementedError):
|
|
_ = HRSCumulativeMultiplier(3, 3, adder=VBERippleCarryAdder(3))
|
|
|
|
def test_plugins(self):
|
|
"""Test setting the HLS plugins for the modular adder."""
|
|
|
|
# all gates with the plugins we check, including an expected operation
|
|
plugins = [("cumulative_h18", "ccircuit-.*"), ("qft_r17", "qft")]
|
|
|
|
num_state_qubits = 2
|
|
|
|
for plugin, expected_op in plugins:
|
|
with self.subTest(plugin=plugin):
|
|
multiplier = MultiplierGate(num_state_qubits)
|
|
|
|
circuit = QuantumCircuit(multiplier.num_qubits)
|
|
circuit.append(multiplier, range(multiplier.num_qubits))
|
|
|
|
hls_config = HLSConfig(Multiplier=[plugin])
|
|
hls = HighLevelSynthesis(hls_config=hls_config)
|
|
|
|
synth = hls(circuit)
|
|
ops = set(synth.count_ops().keys())
|
|
|
|
self.assertTrue(any(re.match(expected_op, op) for op in ops))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|