mirror of https://github.com/llvm/circt.git
[Python] Update bindings to work with new result type inference.
Python bindings now attempt to perform result type inference, so we no longer have to. This is great in general, but we need to complicate our logic to support the NamedValueOpView functionality. In cases where all operands are known a-priori, we defer to ODS-generated result type inference when possible. In these cases, the result type is removed from user-facing APIs, and may require downstream updates to remove the now-redundant result type at call sites. In cases where not all operands are known and we are building backedges, we take care to ensure the ODS-generated result type inference is used. Ops that may build backedges and do not have result type inference now pass a special flag needs_result_type. Closes https://github.com/llvm/circt/issues/1995.
This commit is contained in:
parent
08ed4ff948
commit
3fbedb5ecf
|
@ -18,10 +18,10 @@ with Context() as ctx, Location.unknown():
|
|||
|
||||
def build(module):
|
||||
# CHECK: %[[CONST:.+]] = hw.constant 1 : i32
|
||||
const = hw.ConstantOp(i32, IntegerAttr.get(i32, 1))
|
||||
const = hw.ConstantOp(IntegerAttr.get(i32, 1))
|
||||
|
||||
# CHECK: %[[BIT:.+]] = hw.constant true
|
||||
bit = hw.ConstantOp(i1, IntegerAttr.get(i1, 1))
|
||||
bit = hw.ConstantOp(IntegerAttr.get(i1, 1))
|
||||
|
||||
# CHECK: comb.extract %[[CONST]] from 14
|
||||
comb.ExtractOp.create(14, i32, const.result)
|
||||
|
@ -30,7 +30,7 @@ with Context() as ctx, Location.unknown():
|
|||
connect(extract.input, const.result)
|
||||
|
||||
# CHECK: comb.parity %[[CONST]]
|
||||
comb.ParityOp.create(const.result, result_type=i32)
|
||||
comb.ParityOp.create(const.result)
|
||||
# CHECK: comb.parity %[[CONST]]
|
||||
parity = comb.ParityOp.create(result_type=i32)
|
||||
connect(parity.input, const.result)
|
||||
|
@ -173,10 +173,10 @@ with Context() as ctx, Location.unknown():
|
|||
comb.XorOp.create(const.result, const.result)
|
||||
|
||||
# CHECK: comb.concat %[[CONST]], %[[CONST]]
|
||||
comb.ConcatOp.create(i32, const.result, const.result)
|
||||
comb.ConcatOp.create(const.result, const.result)
|
||||
|
||||
# CHECK: comb.mux %[[BIT]], %[[CONST]], %[[CONST]]
|
||||
comb.MuxOp.create(i32, bit.result, const.result, const.result)
|
||||
comb.MuxOp.create(bit.result, const.result, const.result)
|
||||
|
||||
hw.HWModuleOp(name="test", body_builder=build)
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# REQUIRES: bindings_python
|
||||
# RUN: %PYTHON% %s | FileCheck %s
|
||||
# RUN: %PYTHON% %s 2>&1 | FileCheck %s
|
||||
|
||||
import circt
|
||||
from circt.dialects import comb, hw
|
||||
|
@ -16,19 +16,17 @@ with Context() as ctx, Location.unknown():
|
|||
with InsertionPoint(m.body):
|
||||
|
||||
def build(module):
|
||||
const1 = hw.ConstantOp(i32, IntegerAttr.get(i32, 1))
|
||||
const2 = hw.ConstantOp(i31, IntegerAttr.get(i31, 1))
|
||||
const1 = hw.ConstantOp(IntegerAttr.get(i32, 1))
|
||||
const2 = hw.ConstantOp(IntegerAttr.get(i31, 1))
|
||||
|
||||
# CHECK: expected same input port types, but received [Type(i32), Type(i31)]
|
||||
try:
|
||||
comb.DivSOp.create(const1.result, const2.result)
|
||||
except TypeError as e:
|
||||
print(e)
|
||||
# CHECK: op requires all operands to have the same type
|
||||
div = comb.DivSOp.create(const1.result, const2.result)
|
||||
div.opview.verify()
|
||||
|
||||
# CHECK: result type must be specified
|
||||
# CHECK: result type cannot be None
|
||||
try:
|
||||
comb.DivSOp.create()
|
||||
except TypeError as e:
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
|
||||
hw.HWModuleOp(name="test", body_builder=build)
|
||||
|
|
|
@ -17,7 +17,7 @@ with Context() as ctx, Location.unknown():
|
|||
with InsertionPoint(m.body):
|
||||
|
||||
def build(module):
|
||||
constI32 = hw.ConstantOp(i32, IntegerAttr.get(i32, 1))
|
||||
constI32 = hw.ConstantOp(IntegerAttr.get(i32, 1))
|
||||
constI1 = hw.ConstantOp.create(i1, 1)
|
||||
|
||||
# CHECK: All arguments must be the same type to create an array
|
||||
|
|
|
@ -4,16 +4,6 @@ from circt.support import NamedValueOpView, get_value
|
|||
from mlir.ir import IntegerAttr, IntegerType, Type
|
||||
|
||||
|
||||
def infer_result_type(operands):
|
||||
types = list(map(lambda arg: get_value(arg).type, operands))
|
||||
if not types:
|
||||
raise TypeError("result type must be specified")
|
||||
all_equal = all(type == types[0] for type in types)
|
||||
if not all_equal:
|
||||
raise TypeError(f"expected same input port types, but received {types}")
|
||||
return types[0]
|
||||
|
||||
|
||||
# Builder base classes for non-variadic unary and binary ops.
|
||||
class UnaryOpBuilder(NamedValueOpView):
|
||||
|
||||
|
@ -59,10 +49,6 @@ def BinaryOp(base):
|
|||
|
||||
@classmethod
|
||||
def create(cls, lhs=None, rhs=None, result_type=None):
|
||||
if not result_type:
|
||||
if not lhs and not rhs:
|
||||
raise TypeError("result type must be specified")
|
||||
result_type = infer_result_type([lhs, rhs])
|
||||
mapping = {}
|
||||
if lhs:
|
||||
mapping["lhs"] = lhs
|
||||
|
@ -80,8 +66,7 @@ def VariadicOp(base):
|
|||
|
||||
@classmethod
|
||||
def create(cls, *args):
|
||||
result_type = infer_result_type(args)
|
||||
return cls(result_type, [get_value(a) for a in args])
|
||||
return cls([get_value(a) for a in args])
|
||||
|
||||
return _Class
|
||||
|
||||
|
@ -105,7 +90,10 @@ class ExtractOp:
|
|||
@staticmethod
|
||||
def create(low_bit, result_type, input=None):
|
||||
mapping = {"input": input} if input else {}
|
||||
return ExtractOpBuilder(low_bit, result_type, mapping)
|
||||
return ExtractOpBuilder(low_bit,
|
||||
result_type,
|
||||
mapping,
|
||||
needs_result_type=True)
|
||||
|
||||
|
||||
@UnaryOp
|
||||
|
@ -113,9 +101,12 @@ class ParityOp:
|
|||
pass
|
||||
|
||||
|
||||
@UnaryOp
|
||||
class SExtOp:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def create(cls, input=None, result_type=None):
|
||||
mapping = {"input": input} if input else {}
|
||||
return UnaryOpBuilder(cls, result_type, mapping, needs_result_type=True)
|
||||
|
||||
|
||||
# Sugar classes for the various non-variadic binary ops.
|
||||
|
@ -185,14 +176,12 @@ class XorOp:
|
|||
pass
|
||||
|
||||
|
||||
# Sugar classes for miscellaneous ops.
|
||||
@VariadicOp
|
||||
class ConcatOp:
|
||||
|
||||
@classmethod
|
||||
def create(cls, result_type, *args, **kwargs):
|
||||
return cls(result_type, args, **kwargs)
|
||||
pass
|
||||
|
||||
|
||||
# Sugar classes for miscellaneous ops.
|
||||
@CreatableOp
|
||||
class MuxOp:
|
||||
pass
|
||||
|
|
|
@ -78,6 +78,7 @@ class InstanceBuilder(support.NamedValueOpView):
|
|||
input_port_mapping,
|
||||
pre_args,
|
||||
post_args,
|
||||
needs_result_type=True,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
|
@ -325,7 +326,7 @@ class ConstantOp:
|
|||
|
||||
@staticmethod
|
||||
def create(data_type, value):
|
||||
return hw.ConstantOp(data_type, IntegerAttr.get(data_type, value))
|
||||
return hw.ConstantOp(IntegerAttr.get(data_type, value))
|
||||
|
||||
|
||||
class ArrayGetOp:
|
||||
|
|
|
@ -69,4 +69,5 @@ class CompRegOp:
|
|||
kwargs,
|
||||
reset=reset,
|
||||
reset_value=reset_value,
|
||||
name=name)
|
||||
name=name,
|
||||
needs_result_type=True)
|
||||
|
|
|
@ -32,12 +32,15 @@ def CompareOp(predicate):
|
|||
|
||||
@staticmethod
|
||||
def create(lhs=None, rhs=None):
|
||||
result_type = IntegerType.get_signless(1)
|
||||
mapping = {}
|
||||
if lhs:
|
||||
mapping["lhs"] = lhs
|
||||
if rhs:
|
||||
mapping["rhs"] = rhs
|
||||
if len(mapping) == 0:
|
||||
result_type = IntegerType.get_signless(1)
|
||||
else:
|
||||
result_type = None
|
||||
return ICmpOpBuilder(predicate, result_type, mapping)
|
||||
|
||||
return _Class
|
||||
|
|
|
@ -19,7 +19,7 @@ def reg(value, clock, reset=None, reset_value=None, name=None):
|
|||
if reset:
|
||||
if not reset_value:
|
||||
zero = IntegerAttr.get(value_type, 0)
|
||||
reset_value = hw.ConstantOp(value_type, zero).result
|
||||
reset_value = hw.ConstantOp(zero).result
|
||||
return CompRegOp.create(value_type,
|
||||
input=value,
|
||||
clk=clock,
|
||||
|
|
|
@ -269,11 +269,20 @@ class NamedValueOpView:
|
|||
|
||||
def __init__(self,
|
||||
cls,
|
||||
data_type,
|
||||
input_port_mapping={},
|
||||
pre_args=[],
|
||||
post_args=[],
|
||||
data_type=None,
|
||||
input_port_mapping=None,
|
||||
pre_args=None,
|
||||
post_args=None,
|
||||
needs_result_type=False,
|
||||
**kwargs):
|
||||
# Set defaults
|
||||
if input_port_mapping is None:
|
||||
input_port_mapping = {}
|
||||
if pre_args is None:
|
||||
pre_args = []
|
||||
if post_args is None:
|
||||
post_args = []
|
||||
|
||||
# Set result_indices to name each result.
|
||||
result_names = self.result_names()
|
||||
result_indices = {}
|
||||
|
@ -302,8 +311,16 @@ class NamedValueOpView:
|
|||
if isinstance(data_type, list):
|
||||
operand_values = [operand_values]
|
||||
|
||||
self.opview = cls(data_type, *pre_args, *operand_values, *post_args,
|
||||
**kwargs)
|
||||
# In many cases, result types are inferred, and we do not need to pass
|
||||
# data_type to the underlying constructor. It must be provided to
|
||||
# NamedValueOpView in cases where we need to build backedges, but should
|
||||
# generally not be passed to the underlying constructor in this case. There
|
||||
# are some oddball ops that must pass it, even when building backedges, and
|
||||
# these set needs_result_type=True.
|
||||
if data_type is not None and (needs_result_type or len(backedges) == 0):
|
||||
pre_args.insert(0, data_type)
|
||||
|
||||
self.opview = cls(*pre_args, *operand_values, *post_args, **kwargs)
|
||||
self.operand_indices = operand_indices
|
||||
self.result_indices = result_indices
|
||||
self.backedges = backedges
|
||||
|
|
Loading…
Reference in New Issue