[Python] Update bindings to work with new result type inference.

Python bindings now attempt to perform result type inference, so we no
longer have to. This is great in general, but we need to complicate
our logic to support the NamedValueOpView functionality.

In cases where all operands are known a-priori, we defer to
ODS-generated result type inference when possible. In these cases, the
result type is removed from user-facing APIs, and may require
downstream updates to remove the now-redundant result type at call
sites.

In cases where not all operands are known and we are building
backedges, we take care to ensure the ODS-generated result type
inference is used. Ops that may build backedges and do not have result
type inference now pass a special flag needs_result_type.

Closes https://github.com/llvm/circt/issues/1995.
This commit is contained in:
Mike Urbach 2021-11-04 16:13:58 -06:00
parent 08ed4ff948
commit 3fbedb5ecf
9 changed files with 59 additions and 50 deletions

View File

@ -18,10 +18,10 @@ with Context() as ctx, Location.unknown():
def build(module):
# CHECK: %[[CONST:.+]] = hw.constant 1 : i32
const = hw.ConstantOp(i32, IntegerAttr.get(i32, 1))
const = hw.ConstantOp(IntegerAttr.get(i32, 1))
# CHECK: %[[BIT:.+]] = hw.constant true
bit = hw.ConstantOp(i1, IntegerAttr.get(i1, 1))
bit = hw.ConstantOp(IntegerAttr.get(i1, 1))
# CHECK: comb.extract %[[CONST]] from 14
comb.ExtractOp.create(14, i32, const.result)
@ -30,7 +30,7 @@ with Context() as ctx, Location.unknown():
connect(extract.input, const.result)
# CHECK: comb.parity %[[CONST]]
comb.ParityOp.create(const.result, result_type=i32)
comb.ParityOp.create(const.result)
# CHECK: comb.parity %[[CONST]]
parity = comb.ParityOp.create(result_type=i32)
connect(parity.input, const.result)
@ -173,10 +173,10 @@ with Context() as ctx, Location.unknown():
comb.XorOp.create(const.result, const.result)
# CHECK: comb.concat %[[CONST]], %[[CONST]]
comb.ConcatOp.create(i32, const.result, const.result)
comb.ConcatOp.create(const.result, const.result)
# CHECK: comb.mux %[[BIT]], %[[CONST]], %[[CONST]]
comb.MuxOp.create(i32, bit.result, const.result, const.result)
comb.MuxOp.create(bit.result, const.result, const.result)
hw.HWModuleOp(name="test", body_builder=build)

View File

@ -1,5 +1,5 @@
# REQUIRES: bindings_python
# RUN: %PYTHON% %s | FileCheck %s
# RUN: %PYTHON% %s 2>&1 | FileCheck %s
import circt
from circt.dialects import comb, hw
@ -16,19 +16,17 @@ with Context() as ctx, Location.unknown():
with InsertionPoint(m.body):
def build(module):
const1 = hw.ConstantOp(i32, IntegerAttr.get(i32, 1))
const2 = hw.ConstantOp(i31, IntegerAttr.get(i31, 1))
const1 = hw.ConstantOp(IntegerAttr.get(i32, 1))
const2 = hw.ConstantOp(IntegerAttr.get(i31, 1))
# CHECK: expected same input port types, but received [Type(i32), Type(i31)]
try:
comb.DivSOp.create(const1.result, const2.result)
except TypeError as e:
print(e)
# CHECK: op requires all operands to have the same type
div = comb.DivSOp.create(const1.result, const2.result)
div.opview.verify()
# CHECK: result type must be specified
# CHECK: result type cannot be None
try:
comb.DivSOp.create()
except TypeError as e:
except ValueError as e:
print(e)
hw.HWModuleOp(name="test", body_builder=build)

View File

@ -17,7 +17,7 @@ with Context() as ctx, Location.unknown():
with InsertionPoint(m.body):
def build(module):
constI32 = hw.ConstantOp(i32, IntegerAttr.get(i32, 1))
constI32 = hw.ConstantOp(IntegerAttr.get(i32, 1))
constI1 = hw.ConstantOp.create(i1, 1)
# CHECK: All arguments must be the same type to create an array

View File

@ -4,16 +4,6 @@ from circt.support import NamedValueOpView, get_value
from mlir.ir import IntegerAttr, IntegerType, Type
def infer_result_type(operands):
types = list(map(lambda arg: get_value(arg).type, operands))
if not types:
raise TypeError("result type must be specified")
all_equal = all(type == types[0] for type in types)
if not all_equal:
raise TypeError(f"expected same input port types, but received {types}")
return types[0]
# Builder base classes for non-variadic unary and binary ops.
class UnaryOpBuilder(NamedValueOpView):
@ -59,10 +49,6 @@ def BinaryOp(base):
@classmethod
def create(cls, lhs=None, rhs=None, result_type=None):
if not result_type:
if not lhs and not rhs:
raise TypeError("result type must be specified")
result_type = infer_result_type([lhs, rhs])
mapping = {}
if lhs:
mapping["lhs"] = lhs
@ -80,8 +66,7 @@ def VariadicOp(base):
@classmethod
def create(cls, *args):
result_type = infer_result_type(args)
return cls(result_type, [get_value(a) for a in args])
return cls([get_value(a) for a in args])
return _Class
@ -105,7 +90,10 @@ class ExtractOp:
@staticmethod
def create(low_bit, result_type, input=None):
mapping = {"input": input} if input else {}
return ExtractOpBuilder(low_bit, result_type, mapping)
return ExtractOpBuilder(low_bit,
result_type,
mapping,
needs_result_type=True)
@UnaryOp
@ -113,9 +101,12 @@ class ParityOp:
pass
@UnaryOp
class SExtOp:
pass
@classmethod
def create(cls, input=None, result_type=None):
mapping = {"input": input} if input else {}
return UnaryOpBuilder(cls, result_type, mapping, needs_result_type=True)
# Sugar classes for the various non-variadic binary ops.
@ -185,14 +176,12 @@ class XorOp:
pass
# Sugar classes for miscellaneous ops.
@VariadicOp
class ConcatOp:
@classmethod
def create(cls, result_type, *args, **kwargs):
return cls(result_type, args, **kwargs)
pass
# Sugar classes for miscellaneous ops.
@CreatableOp
class MuxOp:
pass

View File

@ -78,6 +78,7 @@ class InstanceBuilder(support.NamedValueOpView):
input_port_mapping,
pre_args,
post_args,
needs_result_type=True,
loc=loc,
ip=ip)
@ -325,7 +326,7 @@ class ConstantOp:
@staticmethod
def create(data_type, value):
return hw.ConstantOp(data_type, IntegerAttr.get(data_type, value))
return hw.ConstantOp(IntegerAttr.get(data_type, value))
class ArrayGetOp:

View File

@ -69,4 +69,5 @@ class CompRegOp:
kwargs,
reset=reset,
reset_value=reset_value,
name=name)
name=name,
needs_result_type=True)

View File

@ -32,12 +32,15 @@ def CompareOp(predicate):
@staticmethod
def create(lhs=None, rhs=None):
result_type = IntegerType.get_signless(1)
mapping = {}
if lhs:
mapping["lhs"] = lhs
if rhs:
mapping["rhs"] = rhs
if len(mapping) == 0:
result_type = IntegerType.get_signless(1)
else:
result_type = None
return ICmpOpBuilder(predicate, result_type, mapping)
return _Class

View File

@ -19,7 +19,7 @@ def reg(value, clock, reset=None, reset_value=None, name=None):
if reset:
if not reset_value:
zero = IntegerAttr.get(value_type, 0)
reset_value = hw.ConstantOp(value_type, zero).result
reset_value = hw.ConstantOp(zero).result
return CompRegOp.create(value_type,
input=value,
clk=clock,

View File

@ -269,11 +269,20 @@ class NamedValueOpView:
def __init__(self,
cls,
data_type,
input_port_mapping={},
pre_args=[],
post_args=[],
data_type=None,
input_port_mapping=None,
pre_args=None,
post_args=None,
needs_result_type=False,
**kwargs):
# Set defaults
if input_port_mapping is None:
input_port_mapping = {}
if pre_args is None:
pre_args = []
if post_args is None:
post_args = []
# Set result_indices to name each result.
result_names = self.result_names()
result_indices = {}
@ -302,8 +311,16 @@ class NamedValueOpView:
if isinstance(data_type, list):
operand_values = [operand_values]
self.opview = cls(data_type, *pre_args, *operand_values, *post_args,
**kwargs)
# In many cases, result types are inferred, and we do not need to pass
# data_type to the underlying constructor. It must be provided to
# NamedValueOpView in cases where we need to build backedges, but should
# generally not be passed to the underlying constructor in this case. There
# are some oddball ops that must pass it, even when building backedges, and
# these set needs_result_type=True.
if data_type is not None and (needs_result_type or len(backedges) == 0):
pre_args.insert(0, data_type)
self.opview = cls(*pre_args, *operand_values, *post_args, **kwargs)
self.operand_indices = operand_indices
self.result_indices = result_indices
self.backedges = backedges