[Python] Update bindings to work with new result type inference.

Python bindings now attempt to perform result type inference, so we no
longer have to. This is great in general, but we need to complicate
our logic to support the NamedValueOpView functionality.

In cases where all operands are known a-priori, we defer to
ODS-generated result type inference when possible. In these cases, the
result type is removed from user-facing APIs, and may require
downstream updates to remove the now-redundant result type at call
sites.

In cases where not all operands are known and we are building
backedges, we take care to ensure the ODS-generated result type
inference is used. Ops that may build backedges and do not have result
type inference now pass a special flag needs_result_type.

Closes https://github.com/llvm/circt/issues/1995.
This commit is contained in:
Mike Urbach 2021-11-04 16:13:58 -06:00
parent 08ed4ff948
commit 3fbedb5ecf
9 changed files with 59 additions and 50 deletions

View File

@ -18,10 +18,10 @@ with Context() as ctx, Location.unknown():
def build(module): def build(module):
# CHECK: %[[CONST:.+]] = hw.constant 1 : i32 # CHECK: %[[CONST:.+]] = hw.constant 1 : i32
const = hw.ConstantOp(i32, IntegerAttr.get(i32, 1)) const = hw.ConstantOp(IntegerAttr.get(i32, 1))
# CHECK: %[[BIT:.+]] = hw.constant true # CHECK: %[[BIT:.+]] = hw.constant true
bit = hw.ConstantOp(i1, IntegerAttr.get(i1, 1)) bit = hw.ConstantOp(IntegerAttr.get(i1, 1))
# CHECK: comb.extract %[[CONST]] from 14 # CHECK: comb.extract %[[CONST]] from 14
comb.ExtractOp.create(14, i32, const.result) comb.ExtractOp.create(14, i32, const.result)
@ -30,7 +30,7 @@ with Context() as ctx, Location.unknown():
connect(extract.input, const.result) connect(extract.input, const.result)
# CHECK: comb.parity %[[CONST]] # CHECK: comb.parity %[[CONST]]
comb.ParityOp.create(const.result, result_type=i32) comb.ParityOp.create(const.result)
# CHECK: comb.parity %[[CONST]] # CHECK: comb.parity %[[CONST]]
parity = comb.ParityOp.create(result_type=i32) parity = comb.ParityOp.create(result_type=i32)
connect(parity.input, const.result) connect(parity.input, const.result)
@ -173,10 +173,10 @@ with Context() as ctx, Location.unknown():
comb.XorOp.create(const.result, const.result) comb.XorOp.create(const.result, const.result)
# CHECK: comb.concat %[[CONST]], %[[CONST]] # CHECK: comb.concat %[[CONST]], %[[CONST]]
comb.ConcatOp.create(i32, const.result, const.result) comb.ConcatOp.create(const.result, const.result)
# CHECK: comb.mux %[[BIT]], %[[CONST]], %[[CONST]] # CHECK: comb.mux %[[BIT]], %[[CONST]], %[[CONST]]
comb.MuxOp.create(i32, bit.result, const.result, const.result) comb.MuxOp.create(bit.result, const.result, const.result)
hw.HWModuleOp(name="test", body_builder=build) hw.HWModuleOp(name="test", body_builder=build)

View File

@ -1,5 +1,5 @@
# REQUIRES: bindings_python # REQUIRES: bindings_python
# RUN: %PYTHON% %s | FileCheck %s # RUN: %PYTHON% %s 2>&1 | FileCheck %s
import circt import circt
from circt.dialects import comb, hw from circt.dialects import comb, hw
@ -16,19 +16,17 @@ with Context() as ctx, Location.unknown():
with InsertionPoint(m.body): with InsertionPoint(m.body):
def build(module): def build(module):
const1 = hw.ConstantOp(i32, IntegerAttr.get(i32, 1)) const1 = hw.ConstantOp(IntegerAttr.get(i32, 1))
const2 = hw.ConstantOp(i31, IntegerAttr.get(i31, 1)) const2 = hw.ConstantOp(IntegerAttr.get(i31, 1))
# CHECK: expected same input port types, but received [Type(i32), Type(i31)] # CHECK: op requires all operands to have the same type
try: div = comb.DivSOp.create(const1.result, const2.result)
comb.DivSOp.create(const1.result, const2.result) div.opview.verify()
except TypeError as e:
print(e)
# CHECK: result type must be specified # CHECK: result type cannot be None
try: try:
comb.DivSOp.create() comb.DivSOp.create()
except TypeError as e: except ValueError as e:
print(e) print(e)
hw.HWModuleOp(name="test", body_builder=build) hw.HWModuleOp(name="test", body_builder=build)

View File

@ -17,7 +17,7 @@ with Context() as ctx, Location.unknown():
with InsertionPoint(m.body): with InsertionPoint(m.body):
def build(module): def build(module):
constI32 = hw.ConstantOp(i32, IntegerAttr.get(i32, 1)) constI32 = hw.ConstantOp(IntegerAttr.get(i32, 1))
constI1 = hw.ConstantOp.create(i1, 1) constI1 = hw.ConstantOp.create(i1, 1)
# CHECK: All arguments must be the same type to create an array # CHECK: All arguments must be the same type to create an array

View File

@ -4,16 +4,6 @@ from circt.support import NamedValueOpView, get_value
from mlir.ir import IntegerAttr, IntegerType, Type from mlir.ir import IntegerAttr, IntegerType, Type
def infer_result_type(operands):
types = list(map(lambda arg: get_value(arg).type, operands))
if not types:
raise TypeError("result type must be specified")
all_equal = all(type == types[0] for type in types)
if not all_equal:
raise TypeError(f"expected same input port types, but received {types}")
return types[0]
# Builder base classes for non-variadic unary and binary ops. # Builder base classes for non-variadic unary and binary ops.
class UnaryOpBuilder(NamedValueOpView): class UnaryOpBuilder(NamedValueOpView):
@ -59,10 +49,6 @@ def BinaryOp(base):
@classmethod @classmethod
def create(cls, lhs=None, rhs=None, result_type=None): def create(cls, lhs=None, rhs=None, result_type=None):
if not result_type:
if not lhs and not rhs:
raise TypeError("result type must be specified")
result_type = infer_result_type([lhs, rhs])
mapping = {} mapping = {}
if lhs: if lhs:
mapping["lhs"] = lhs mapping["lhs"] = lhs
@ -80,8 +66,7 @@ def VariadicOp(base):
@classmethod @classmethod
def create(cls, *args): def create(cls, *args):
result_type = infer_result_type(args) return cls([get_value(a) for a in args])
return cls(result_type, [get_value(a) for a in args])
return _Class return _Class
@ -105,7 +90,10 @@ class ExtractOp:
@staticmethod @staticmethod
def create(low_bit, result_type, input=None): def create(low_bit, result_type, input=None):
mapping = {"input": input} if input else {} mapping = {"input": input} if input else {}
return ExtractOpBuilder(low_bit, result_type, mapping) return ExtractOpBuilder(low_bit,
result_type,
mapping,
needs_result_type=True)
@UnaryOp @UnaryOp
@ -113,9 +101,12 @@ class ParityOp:
pass pass
@UnaryOp
class SExtOp: class SExtOp:
pass
@classmethod
def create(cls, input=None, result_type=None):
mapping = {"input": input} if input else {}
return UnaryOpBuilder(cls, result_type, mapping, needs_result_type=True)
# Sugar classes for the various non-variadic binary ops. # Sugar classes for the various non-variadic binary ops.
@ -185,14 +176,12 @@ class XorOp:
pass pass
# Sugar classes for miscellaneous ops. @VariadicOp
class ConcatOp: class ConcatOp:
pass
@classmethod
def create(cls, result_type, *args, **kwargs):
return cls(result_type, args, **kwargs)
# Sugar classes for miscellaneous ops.
@CreatableOp @CreatableOp
class MuxOp: class MuxOp:
pass pass

View File

@ -78,6 +78,7 @@ class InstanceBuilder(support.NamedValueOpView):
input_port_mapping, input_port_mapping,
pre_args, pre_args,
post_args, post_args,
needs_result_type=True,
loc=loc, loc=loc,
ip=ip) ip=ip)
@ -325,7 +326,7 @@ class ConstantOp:
@staticmethod @staticmethod
def create(data_type, value): def create(data_type, value):
return hw.ConstantOp(data_type, IntegerAttr.get(data_type, value)) return hw.ConstantOp(IntegerAttr.get(data_type, value))
class ArrayGetOp: class ArrayGetOp:

View File

@ -69,4 +69,5 @@ class CompRegOp:
kwargs, kwargs,
reset=reset, reset=reset,
reset_value=reset_value, reset_value=reset_value,
name=name) name=name,
needs_result_type=True)

View File

@ -32,12 +32,15 @@ def CompareOp(predicate):
@staticmethod @staticmethod
def create(lhs=None, rhs=None): def create(lhs=None, rhs=None):
result_type = IntegerType.get_signless(1)
mapping = {} mapping = {}
if lhs: if lhs:
mapping["lhs"] = lhs mapping["lhs"] = lhs
if rhs: if rhs:
mapping["rhs"] = rhs mapping["rhs"] = rhs
if len(mapping) == 0:
result_type = IntegerType.get_signless(1)
else:
result_type = None
return ICmpOpBuilder(predicate, result_type, mapping) return ICmpOpBuilder(predicate, result_type, mapping)
return _Class return _Class

View File

@ -19,7 +19,7 @@ def reg(value, clock, reset=None, reset_value=None, name=None):
if reset: if reset:
if not reset_value: if not reset_value:
zero = IntegerAttr.get(value_type, 0) zero = IntegerAttr.get(value_type, 0)
reset_value = hw.ConstantOp(value_type, zero).result reset_value = hw.ConstantOp(zero).result
return CompRegOp.create(value_type, return CompRegOp.create(value_type,
input=value, input=value,
clk=clock, clk=clock,

View File

@ -269,11 +269,20 @@ class NamedValueOpView:
def __init__(self, def __init__(self,
cls, cls,
data_type, data_type=None,
input_port_mapping={}, input_port_mapping=None,
pre_args=[], pre_args=None,
post_args=[], post_args=None,
needs_result_type=False,
**kwargs): **kwargs):
# Set defaults
if input_port_mapping is None:
input_port_mapping = {}
if pre_args is None:
pre_args = []
if post_args is None:
post_args = []
# Set result_indices to name each result. # Set result_indices to name each result.
result_names = self.result_names() result_names = self.result_names()
result_indices = {} result_indices = {}
@ -302,8 +311,16 @@ class NamedValueOpView:
if isinstance(data_type, list): if isinstance(data_type, list):
operand_values = [operand_values] operand_values = [operand_values]
self.opview = cls(data_type, *pre_args, *operand_values, *post_args, # In many cases, result types are inferred, and we do not need to pass
**kwargs) # data_type to the underlying constructor. It must be provided to
# NamedValueOpView in cases where we need to build backedges, but should
# generally not be passed to the underlying constructor in this case. There
# are some oddball ops that must pass it, even when building backedges, and
# these set needs_result_type=True.
if data_type is not None and (needs_result_type or len(backedges) == 0):
pre_args.insert(0, data_type)
self.opview = cls(*pre_args, *operand_values, *post_args, **kwargs)
self.operand_indices = operand_indices self.operand_indices = operand_indices
self.result_indices = result_indices self.result_indices = result_indices
self.backedges = backedges self.backedges = backedges