qiskit/test/python/circuit/classical/test_expr_constructors.py

920 lines
33 KiB
Python

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
import ddt
from qiskit.circuit import Clbit, ClassicalRegister, Duration
from qiskit.circuit.classical import expr, types
from test import QiskitTestCase # pylint: disable=wrong-import-order
@ddt.ddt
class TestExprConstructors(QiskitTestCase):
def test_lift_legacy_condition(self):
cr = ClassicalRegister(3, "c")
clbit = Clbit()
cond = (cr, 7)
self.assertEqual(
expr.lift_legacy_condition(cond),
expr.Binary(
expr.Binary.Op.EQUAL,
expr.Var(cr, types.Uint(cr.size)),
expr.Value(7, types.Uint(cr.size)),
types.Bool(),
),
)
cond = (cr, 255)
self.assertEqual(
expr.lift_legacy_condition(cond),
expr.Binary(
expr.Binary.Op.EQUAL,
expr.Cast(expr.Var(cr, types.Uint(cr.size)), types.Uint(8), implicit=True),
expr.Value(255, types.Uint(8)),
types.Bool(),
),
)
cond = (clbit, False)
self.assertEqual(
expr.lift_legacy_condition(cond),
expr.Unary(
expr.Unary.Op.LOGIC_NOT,
expr.Var(clbit, types.Bool()),
types.Bool(),
),
)
cond = (clbit, True)
self.assertEqual(
expr.lift_legacy_condition(cond),
expr.Var(clbit, types.Bool()),
)
def test_value_lifts_qiskit_scalars(self):
cr = ClassicalRegister(3, "c")
self.assertEqual(expr.lift(cr), expr.Var(cr, types.Uint(cr.size)))
clbit = Clbit()
self.assertEqual(expr.lift(clbit), expr.Var(clbit, types.Bool()))
duration = Duration.dt(1000)
self.assertEqual(expr.lift(duration), expr.Value(duration, types.Duration()))
def test_value_lifts_python_builtins(self):
self.assertEqual(expr.lift(True), expr.Value(True, types.Bool()))
self.assertEqual(expr.lift(False), expr.Value(False, types.Bool()))
self.assertEqual(expr.lift(7), expr.Value(7, types.Uint(3)))
self.assertEqual(expr.lift(7.0), expr.Value(7.0, types.Float()))
def test_value_ensures_nonzero_width(self):
self.assertEqual(expr.lift(0), expr.Value(0, types.Uint(1)))
def test_value_type_representation(self):
self.assertEqual(expr.lift(5), expr.Value(5, types.Uint((5).bit_length())))
self.assertEqual(expr.lift(5, types.Uint(8)), expr.Value(5, types.Uint(8)))
cr = ClassicalRegister(3, "c")
self.assertEqual(expr.lift(cr, types.Uint(8)), expr.Var(cr, types.Uint(8)))
def test_value_does_not_allow_downcast(self):
with self.assertRaisesRegex(TypeError, "the explicit type .* is not suitable"):
expr.lift(0xFF, types.Uint(2))
with self.assertRaisesRegex(TypeError, "the explicit type .* is not suitable"):
expr.lift(1.1, types.Uint(2))
def test_value_rejects_bad_values(self):
with self.assertRaisesRegex(TypeError, "failed to infer a type"):
expr.lift("1")
with self.assertRaisesRegex(ValueError, "cannot represent a negative value"):
expr.lift(-1)
def test_cast_adds_explicit_nodes(self):
"""A specific request to add a cast in means that we should respect that in the type tree,
even if the cast is a no-op."""
base = expr.Value(5, types.Uint(8))
self.assertEqual(
expr.cast(base, types.Uint(8)), expr.Cast(base, types.Uint(8), implicit=False)
)
def test_cast_allows_lossy_downcasting(self):
"""An explicit 'cast' call should allow lossy casts to be performed."""
base = expr.Value(5, types.Uint(16))
self.assertEqual(
expr.cast(base, types.Uint(8)), expr.Cast(base, types.Uint(8), implicit=False)
)
self.assertEqual(
expr.cast(base, types.Bool()), expr.Cast(base, types.Bool(), implicit=False)
)
self.assertEqual(
expr.cast(base, types.Float()), expr.Cast(base, types.Float(), implicit=False)
)
@ddt.data(
(expr.bit_not, ClassicalRegister(3)),
(expr.logic_not, ClassicalRegister(3)),
(expr.logic_not, False),
(expr.logic_not, Clbit()),
)
@ddt.unpack
def test_unary_functions_lift_scalars(self, function, scalar):
self.assertEqual(function(scalar), function(expr.lift(scalar)))
def test_bit_not_explicit(self):
cr = ClassicalRegister(3)
self.assertEqual(
expr.bit_not(cr),
expr.Unary(
expr.Unary.Op.BIT_NOT, expr.Var(cr, types.Uint(cr.size)), types.Uint(cr.size)
),
)
clbit = Clbit()
self.assertEqual(
expr.bit_not(clbit),
expr.Unary(expr.Unary.Op.BIT_NOT, expr.Var(clbit, types.Bool()), types.Bool()),
)
@ddt.data(expr.bit_not)
def test_unary_bitwise_forbidden(self, function):
with self.assertRaisesRegex(TypeError, "cannot apply"):
function(7.0)
with self.assertRaisesRegex(TypeError, "cannot apply"):
function(Duration.dt(1000))
def test_logic_not_explicit(self):
cr = ClassicalRegister(3)
self.assertEqual(
expr.logic_not(cr),
expr.Unary(
expr.Unary.Op.LOGIC_NOT,
expr.Cast(expr.Var(cr, types.Uint(cr.size)), types.Bool(), implicit=True),
types.Bool(),
),
)
clbit = Clbit()
self.assertEqual(
expr.logic_not(clbit),
expr.Unary(expr.Unary.Op.LOGIC_NOT, expr.Var(clbit, types.Bool()), types.Bool()),
)
@ddt.data(expr.logic_not)
def test_unary_logical_forbidden(self, function):
with self.assertRaisesRegex(TypeError, "cannot apply"):
function(7.0)
with self.assertRaisesRegex(TypeError, "cannot apply"):
function(Duration.dt(1000))
@ddt.data(
(expr.bit_and, ClassicalRegister(3), ClassicalRegister(3)),
(expr.bit_or, ClassicalRegister(3), ClassicalRegister(3)),
(expr.bit_xor, ClassicalRegister(3), ClassicalRegister(3)),
(expr.logic_and, Clbit(), True),
(expr.logic_or, False, ClassicalRegister(3)),
(expr.equal, ClassicalRegister(8), 255),
(expr.not_equal, ClassicalRegister(8), 255),
(expr.less, ClassicalRegister(3), 6),
(expr.less_equal, ClassicalRegister(3), 5),
(expr.greater, 4, ClassicalRegister(3)),
(expr.greater_equal, ClassicalRegister(3), 5),
(expr.add, ClassicalRegister(3), 6),
(expr.sub, ClassicalRegister(3), 5),
(expr.mul, 4, ClassicalRegister(3)),
(expr.div, ClassicalRegister(3), 5),
(expr.equal, 8.0, 255.0),
(expr.not_equal, 8.0, 255.0),
(expr.less, 3.0, 6.0),
(expr.less_equal, 3.0, 5.0),
(expr.greater, 4.0, 3.0),
(expr.greater_equal, 3.0, 5.0),
(expr.add, 3.0, 6.0),
(expr.sub, 3.0, 5.0),
(expr.mul, 4.0, 3.0),
(expr.div, 3.0, 5.0),
(expr.equal, Duration.dt(1000), Duration.dt(1000)),
(expr.not_equal, Duration.dt(1000), Duration.dt(1000)),
(expr.less, Duration.dt(1000), Duration.dt(1000)),
(expr.less_equal, Duration.dt(1000), Duration.dt(1000)),
(expr.greater, Duration.dt(1000), Duration.dt(1000)),
(expr.greater_equal, Duration.dt(1000), Duration.dt(1000)),
(expr.add, Duration.dt(1000), Duration.dt(1000)),
(expr.sub, Duration.dt(1000), Duration.dt(1000)),
(expr.div, Duration.dt(1000), Duration.dt(1000)),
)
@ddt.unpack
def test_binary_functions_lift_scalars(self, function, left, right):
self.assertEqual(function(left, right), function(expr.lift(left), right))
self.assertEqual(function(left, right), function(left, expr.lift(right)))
self.assertEqual(function(left, right), function(expr.lift(left), expr.lift(right)))
@ddt.data(
(expr.bit_and, expr.Binary.Op.BIT_AND),
(expr.bit_or, expr.Binary.Op.BIT_OR),
(expr.bit_xor, expr.Binary.Op.BIT_XOR),
)
@ddt.unpack
def test_binary_bitwise_explicit(self, function, opcode):
cr = ClassicalRegister(8, "c")
self.assertEqual(
function(cr, 255),
expr.Binary(
opcode, expr.Var(cr, types.Uint(8)), expr.Value(255, types.Uint(8)), types.Uint(8)
),
)
self.assertFalse(function(cr, 255).const)
self.assertEqual(
function(255, cr),
expr.Binary(
opcode, expr.Value(255, types.Uint(8)), expr.Var(cr, types.Uint(8)), types.Uint(8)
),
)
self.assertFalse(function(255, cr).const)
clbit = Clbit()
self.assertEqual(
function(True, clbit),
expr.Binary(
opcode,
expr.Value(True, types.Bool()),
expr.Var(clbit, types.Bool()),
types.Bool(),
),
)
self.assertFalse(function(True, clbit).const)
self.assertEqual(
function(clbit, False),
expr.Binary(
opcode,
expr.Var(clbit, types.Bool()),
expr.Value(False, types.Bool()),
types.Bool(),
),
)
self.assertFalse(function(clbit, True).const)
self.assertEqual(
function(255, 255),
expr.Binary(
opcode,
expr.Value(255, types.Uint(8)),
expr.Value(255, types.Uint(8)),
types.Uint(8),
),
)
self.assertTrue(function(255, 255).const)
@ddt.data(
(expr.bit_and, expr.Binary.Op.BIT_AND),
(expr.bit_or, expr.Binary.Op.BIT_OR),
(expr.bit_xor, expr.Binary.Op.BIT_XOR),
)
@ddt.unpack
def test_binary_bitwise_uint_inference(self, function, opcode):
"""The binary bitwise functions have specialized inference for the widths of integer
literals, since the bitwise functions require the operands to already be of exactly the same
width without promotion."""
cr = ClassicalRegister(8, "c")
self.assertEqual(
function(cr, 5),
expr.Binary(
opcode,
expr.Var(cr, types.Uint(8)),
expr.Value(5, types.Uint(8)), # Note the inference should be Uint(8) not Uint(3).
types.Uint(8),
),
)
self.assertEqual(
function(5, cr),
expr.Binary(
opcode,
expr.Value(5, types.Uint(8)),
expr.Var(cr, types.Uint(8)),
types.Uint(8),
),
)
# Inference between two integer literals is "best effort". This behavior isn't super
# important to maintain if we want to change the expression system.
self.assertEqual(
function(5, 255),
expr.Binary(
opcode,
expr.Value(5, types.Uint(8)),
expr.Value(255, types.Uint(8)),
types.Uint(8),
),
)
@ddt.data(expr.bit_and, expr.bit_or, expr.bit_xor)
def test_binary_bitwise_forbidden(self, function):
with self.assertRaisesRegex(TypeError, "invalid types"):
function(ClassicalRegister(3, "c"), Clbit())
with self.assertRaisesRegex(TypeError, "invalid types"):
function(3.0, 3.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(3, 3.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(3.0, 3)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Duration.dt(1000), Duration.dt(1000))
# Unlike most other functions, the bitwise functions should error if the two bit-like types
# aren't of the same width, except for the special inference for integer literals.
with self.assertRaisesRegex(TypeError, "binary bitwise operations .* same width"):
function(ClassicalRegister(3, "a"), ClassicalRegister(5, "b"))
@ddt.data(
(expr.logic_and, expr.Binary.Op.LOGIC_AND),
(expr.logic_or, expr.Binary.Op.LOGIC_OR),
)
@ddt.unpack
def test_binary_logical_explicit(self, function, opcode):
cr = ClassicalRegister(8, "c")
clbit = Clbit()
self.assertEqual(
function(cr, clbit),
expr.Binary(
opcode,
expr.Cast(expr.Var(cr, types.Uint(cr.size)), types.Bool(), implicit=True),
expr.Var(clbit, types.Bool()),
types.Bool(),
),
)
self.assertFalse(function(cr, clbit).const)
self.assertEqual(
function(cr, 3),
expr.Binary(
opcode,
expr.Cast(expr.Var(cr, types.Uint(cr.size)), types.Bool(), implicit=True),
expr.Cast(expr.Value(3, types.Uint(2)), types.Bool(), implicit=True),
types.Bool(),
),
)
self.assertFalse(function(cr, 3).const)
self.assertEqual(
function(False, clbit),
expr.Binary(
opcode,
expr.Value(False, types.Bool()),
expr.Var(clbit, types.Bool()),
types.Bool(),
),
)
self.assertFalse(function(False, clbit).const)
@ddt.data(expr.logic_and, expr.logic_or)
def test_binary_logic_forbidden(self, function):
with self.assertRaisesRegex(TypeError, "invalid types"):
function(3.0, 3.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(3, 3.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(3.0, 3)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Duration.dt(1000), Duration.dt(1000))
@ddt.data(
(expr.equal, expr.Binary.Op.EQUAL),
(expr.not_equal, expr.Binary.Op.NOT_EQUAL),
)
@ddt.unpack
def test_binary_equal_explicit(self, function, opcode):
cr = ClassicalRegister(8, "c")
clbit = Clbit()
self.assertEqual(
function(cr, 255),
expr.Binary(
opcode, expr.Var(cr, types.Uint(8)), expr.Value(255, types.Uint(8)), types.Bool()
),
)
self.assertFalse(function(cr, 255).const)
self.assertEqual(
function(7, cr),
expr.Binary(
opcode,
expr.Value(7, types.Uint(8)),
expr.Var(cr, types.Uint(8)),
types.Bool(),
),
)
self.assertFalse(function(7, cr).const)
self.assertEqual(
function(clbit, True),
expr.Binary(
opcode,
expr.Var(clbit, types.Bool()),
expr.Value(True, types.Bool()),
types.Bool(),
),
)
self.assertFalse(function(clbit, True).const)
self.assertEqual(
function(expr.lift(7.0), 7.0),
expr.Binary(
opcode,
expr.Value(7.0, types.Float()),
expr.Value(7.0, types.Float()),
types.Bool(),
),
)
self.assertTrue(function(expr.lift(7.0), 7.0).const)
self.assertEqual(
function(expr.lift(Duration.ms(1000)), Duration.s(1)),
expr.Binary(
opcode,
expr.Value(Duration.ms(1000), types.Duration()),
expr.Value(Duration.s(1), types.Duration()),
types.Bool(),
),
)
self.assertTrue(function(expr.lift(Duration.ms(1000)), Duration.s(1)).const)
@ddt.data(expr.equal, expr.not_equal)
def test_binary_equal_forbidden(self, function):
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Clbit(), ClassicalRegister(3, "c"))
with self.assertRaisesRegex(TypeError, "invalid types"):
function(ClassicalRegister(3, "c"), False)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(5, True)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(True, 5.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(5, 5.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(ClassicalRegister(3, "c"), 5.0)
@ddt.data(
(expr.less, expr.Binary.Op.LESS),
(expr.less_equal, expr.Binary.Op.LESS_EQUAL),
(expr.greater, expr.Binary.Op.GREATER),
(expr.greater_equal, expr.Binary.Op.GREATER_EQUAL),
)
@ddt.unpack
def test_binary_relation_explicit(self, function, opcode):
cr = ClassicalRegister(8, "c")
self.assertEqual(
function(cr, 200),
expr.Binary(
opcode, expr.Var(cr, types.Uint(8)), expr.Value(200, types.Uint(8)), types.Bool()
),
)
self.assertFalse(function(cr, 200).const)
self.assertEqual(
function(12, cr),
expr.Binary(
opcode,
expr.Value(12, types.Uint(8)),
expr.Var(cr, types.Uint(8)),
types.Bool(),
),
)
self.assertFalse(function(12, cr).const)
self.assertEqual(
function(expr.lift(12.0, types.Float()), expr.lift(12.0)),
expr.Binary(
opcode,
expr.Value(12.0, types.Float()),
expr.Value(12.0, types.Float()),
types.Bool(),
),
)
self.assertTrue(function(expr.lift(12.0, types.Float()), expr.lift(12.0)).const)
self.assertEqual(
function(
expr.lift(Duration.ms(1000), types.Duration()),
expr.lift(Duration.s(1)),
),
expr.Binary(
opcode,
expr.Value(Duration.ms(1000), types.Duration()),
expr.Value(Duration.s(1), types.Duration()),
types.Bool(),
),
)
self.assertTrue(
function(
expr.lift(Duration.ms(1000), types.Duration()),
expr.lift(Duration.s(1)),
).const
)
@ddt.data(expr.less, expr.less_equal, expr.greater, expr.greater_equal)
def test_binary_relation_forbidden(self, function):
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Clbit(), ClassicalRegister(3, "c"))
with self.assertRaisesRegex(TypeError, "invalid types"):
function(ClassicalRegister(3, "c"), False)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Clbit(), Clbit())
with self.assertRaisesRegex(TypeError, "invalid types"):
function(True, 5.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(5, 5.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(ClassicalRegister(3, "c"), 5.0)
def test_index_explicit(self):
cr = ClassicalRegister(4, "c")
a = expr.Var.new("a", types.Uint(8))
self.assertEqual(
expr.index(cr, 3),
expr.Index(expr.Var(cr, types.Uint(4)), expr.Value(3, types.Uint(2)), types.Bool()),
)
self.assertFalse(expr.index(cr, 3).const)
self.assertEqual(
expr.index(a, cr),
expr.Index(a, expr.Var(cr, types.Uint(4)), types.Bool()),
)
self.assertFalse(expr.index(a, cr).const)
self.assertEqual(
expr.index(255, 1),
expr.Index(expr.Value(255, types.Uint(8)), expr.Value(1, types.Uint(1)), types.Bool()),
)
self.assertTrue(expr.index(255, 1).const)
def test_index_forbidden(self):
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.index(Clbit(), 3)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.index(ClassicalRegister(3, "a"), False)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.index(ClassicalRegister(3, "a"), 1.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.index(0xFFFF, 1.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.index(ClassicalRegister(3, "a"), 1.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.index(Duration.dt(1000), 1)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.index(Duration.dt(1000), Duration.dt(1000))
@ddt.data(
(expr.shift_left, expr.Binary.Op.SHIFT_LEFT),
(expr.shift_right, expr.Binary.Op.SHIFT_RIGHT),
)
@ddt.unpack
def test_shift_explicit(self, function, opcode):
cr = ClassicalRegister(8, "c")
a = expr.Var.new("a", types.Uint(4))
self.assertEqual(
function(cr, 5),
expr.Binary(
opcode, expr.Var(cr, types.Uint(8)), expr.Value(5, types.Uint(3)), types.Uint(8)
),
)
self.assertFalse(function(cr, 5).const)
self.assertEqual(
function(a, cr),
expr.Binary(opcode, a, expr.Var(cr, types.Uint(8)), types.Uint(4)),
)
self.assertFalse(function(a, cr).const)
self.assertEqual(
function(3, 5, types.Uint(8)),
expr.Binary(
opcode, expr.Value(3, types.Uint(8)), expr.Value(5, types.Uint(3)), types.Uint(8)
),
)
self.assertTrue(function(3, 5, types.Uint(8)).const)
@ddt.data(expr.shift_left, expr.shift_right)
def test_shift_forbidden(self, function):
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Clbit(), ClassicalRegister(3, "c"))
with self.assertRaisesRegex(TypeError, "invalid types"):
function(ClassicalRegister(3, "c"), False)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Clbit(), Clbit())
with self.assertRaisesRegex(TypeError, "invalid types"):
function(0xFFFF, 2.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(255.0, 1)
with self.assertRaisesRegex(TypeError, "cannot losslessly represent"):
function(expr.lift(5.0), 3, types.Uint(8))
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Duration.dt(1000), 1)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Duration.dt(1000), Duration.dt(1000))
@ddt.data(
(expr.add, expr.Binary.Op.ADD),
(expr.sub, expr.Binary.Op.SUB),
)
@ddt.unpack
def test_binary_sum_explicit(self, function, opcode):
cr = ClassicalRegister(8, "c")
self.assertEqual(
function(cr, 200),
expr.Binary(
opcode, expr.Var(cr, types.Uint(8)), expr.Value(200, types.Uint(8)), types.Uint(8)
),
)
self.assertFalse(function(cr, 200).const)
self.assertEqual(
function(12, cr),
expr.Binary(
opcode,
expr.Value(12, types.Uint(8)),
expr.Var(cr, types.Uint(8)),
types.Uint(8),
),
)
self.assertFalse(function(12, cr).const)
self.assertEqual(
function(12.5, 2.0),
expr.Binary(
opcode,
expr.Value(12.5, types.Float()),
expr.Value(2.0, types.Float()),
types.Float(),
),
)
self.assertTrue(function(12.5, 2.0).const)
self.assertEqual(
function(
expr.lift(Duration.ms(1000), types.Duration()),
expr.lift(Duration.s(1)),
),
expr.Binary(
opcode,
expr.Value(Duration.ms(1000), types.Duration()),
expr.Value(Duration.s(1), types.Duration()),
types.Duration(),
),
)
self.assertTrue(
function(
expr.lift(Duration.ms(1000), types.Duration()),
expr.lift(Duration.s(1)),
).const
)
@ddt.data(expr.add, expr.sub)
def test_binary_sum_forbidden(self, function):
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Clbit(), ClassicalRegister(3, "c"))
with self.assertRaisesRegex(TypeError, "invalid types"):
function(ClassicalRegister(3, "c"), False)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Clbit(), Clbit())
with self.assertRaisesRegex(TypeError, "invalid types"):
function(0xFFFF, 2.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(255.0, 1)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Duration.dt(1000), 1)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Duration.dt(1000), 1.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
function(Duration.dt(1000), expr.lift(1.0))
def test_mul_explicit(self):
cr = ClassicalRegister(8, "c")
self.assertEqual(
expr.mul(cr, 200),
expr.Binary(
expr.Binary.Op.MUL,
expr.Var(cr, types.Uint(8)),
expr.Value(200, types.Uint(8)),
types.Uint(8),
),
)
self.assertFalse(expr.mul(cr, 200).const)
self.assertEqual(
expr.mul(12, cr),
expr.Binary(
expr.Binary.Op.MUL,
expr.Value(12, types.Uint(8)),
expr.Var(cr, types.Uint(8)),
types.Uint(8),
),
)
self.assertFalse(expr.mul(12, cr).const)
self.assertEqual(
expr.mul(expr.lift(12), cr),
expr.Binary(
expr.Binary.Op.MUL,
# Explicit cast required to get from Uint(4) to Uint(8)
expr.Cast(expr.Value(12, types.Uint(4)), types.Uint(8), implicit=False),
expr.Var(cr, types.Uint(8)),
types.Uint(8),
),
)
self.assertFalse(expr.mul(12, cr).const)
self.assertEqual(
expr.mul(expr.lift(12, types.Uint(8)), expr.lift(12)),
expr.Binary(
expr.Binary.Op.MUL,
expr.Value(12, types.Uint(8)),
expr.Cast(
expr.Value(12, types.Uint(4)),
types.Uint(8),
implicit=False,
),
types.Uint(8),
),
)
self.assertTrue(expr.mul(expr.lift(12, types.Uint(8)), expr.lift(12)).const)
self.assertEqual(
expr.mul(expr.lift(12.0, types.Float()), expr.lift(12.0)),
expr.Binary(
expr.Binary.Op.MUL,
expr.Value(12.0, types.Float()),
expr.Value(12.0, types.Float()),
types.Float(),
),
)
self.assertTrue(expr.mul(expr.lift(12.0, types.Float()), expr.lift(12.0)).const)
self.assertEqual(
expr.mul(Duration.ms(1000), 2.0),
expr.Binary(
expr.Binary.Op.MUL,
expr.Value(Duration.ms(1000), types.Duration()),
expr.Value(2.0, types.Float()),
types.Duration(),
),
)
self.assertTrue(expr.mul(Duration.ms(1000), 2.0).const)
self.assertEqual(
expr.mul(2.0, Duration.ms(1000)),
expr.Binary(
expr.Binary.Op.MUL,
expr.Value(2.0, types.Float()),
expr.Value(Duration.ms(1000), types.Duration()),
types.Duration(),
),
)
self.assertTrue(expr.mul(2.0, Duration.ms(1000)).const)
self.assertEqual(
expr.mul(2, Duration.ms(1000)),
expr.Binary(
expr.Binary.Op.MUL,
expr.Value(2, types.Uint(2)),
expr.Value(Duration.ms(1000), types.Duration()),
types.Duration(),
),
)
self.assertTrue(expr.mul(2, Duration.ms(1000)).const)
def test_mul_forbidden(self):
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.mul(Clbit(), ClassicalRegister(3, "c"))
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.mul(ClassicalRegister(3, "c"), False)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.mul(Clbit(), Clbit())
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.mul(0xFFFF, 2.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.mul(255.0, 1)
with self.assertRaisesRegex(TypeError, "cannot multiply two durations"):
expr.mul(Duration.dt(1000), Duration.dt(1000))
def test_div_explicit(self):
cr = ClassicalRegister(8, "c")
self.assertEqual(
expr.div(cr, 200),
expr.Binary(
expr.Binary.Op.DIV,
expr.Var(cr, types.Uint(8)),
expr.Value(200, types.Uint(8)),
types.Uint(8),
),
)
self.assertFalse(expr.div(cr, 200).const)
self.assertEqual(
expr.div(12, cr),
expr.Binary(
expr.Binary.Op.DIV,
expr.Value(12, types.Uint(8)),
expr.Var(cr, types.Uint(8)),
types.Uint(8),
),
)
self.assertFalse(expr.div(12, cr).const)
self.assertEqual(
expr.div(expr.lift(12), cr),
expr.Binary(
expr.Binary.Op.DIV,
# Explicit cast required to get from Uint(4) to Uint(8)
expr.Cast(expr.Value(12, types.Uint(4)), types.Uint(8), implicit=False),
expr.Var(cr, types.Uint(8)),
types.Uint(8),
),
)
self.assertFalse(expr.div(expr.lift(12), cr).const)
self.assertEqual(
expr.div(expr.lift(12, types.Uint(8)), expr.lift(12)),
expr.Binary(
expr.Binary.Op.DIV,
expr.Value(12, types.Uint(8)),
expr.Cast(
expr.Value(12, types.Uint(4)),
types.Uint(8),
implicit=False,
),
types.Uint(8),
),
)
self.assertTrue(expr.div(expr.lift(12, types.Uint(8)), expr.lift(12)).const)
self.assertEqual(
expr.div(expr.lift(12.0, types.Float()), expr.lift(12.0)),
expr.Binary(
expr.Binary.Op.DIV,
expr.Value(12.0, types.Float()),
expr.Value(12.0, types.Float()),
types.Float(),
),
)
self.assertTrue(expr.div(expr.lift(12.0, types.Float()), expr.lift(12.0)).const)
self.assertEqual(
expr.div(Duration.ms(1000), 2.0),
expr.Binary(
expr.Binary.Op.DIV,
expr.Value(Duration.ms(1000), types.Duration()),
expr.Value(2.0, types.Float()),
types.Duration(),
),
)
self.assertTrue(expr.div(Duration.ms(1000), 2.0).const)
self.assertEqual(
expr.div(Duration.ms(1000), 2),
expr.Binary(
expr.Binary.Op.DIV,
expr.Value(Duration.ms(1000), types.Duration()),
expr.Value(2, types.Uint(2)),
types.Duration(),
),
)
self.assertTrue(expr.div(Duration.ms(1000), 2).const)
self.assertEqual(
expr.div(Duration.ms(1000), Duration.ms(1000)),
expr.Binary(
expr.Binary.Op.DIV,
expr.Value(Duration.ms(1000), types.Duration()),
expr.Value(Duration.ms(1000), types.Duration()),
types.Float(),
),
)
self.assertTrue(expr.div(Duration.ms(1000), Duration.ms(1000)).const)
def test_div_forbidden(self):
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.div(Clbit(), ClassicalRegister(3, "c"))
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.div(ClassicalRegister(3, "c"), False)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.div(Clbit(), Clbit())
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.div(0xFFFF, 2.0)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.div(255.0, 1)
with self.assertRaisesRegex(TypeError, "invalid types"):
expr.div(255.0, Duration.dt(1000))