mirror of https://github.com/llvm/circt.git
[Python] Update bindings to work with new result type inference.
Python bindings now attempt to perform result type inference, so we no longer have to. This is great in general, but we need to complicate our logic to support the NamedValueOpView functionality. In cases where all operands are known a-priori, we defer to ODS-generated result type inference when possible. In these cases, the result type is removed from user-facing APIs, and may require downstream updates to remove the now-redundant result type at call sites. In cases where not all operands are known and we are building backedges, we take care to ensure the ODS-generated result type inference is used. Ops that may build backedges and do not have result type inference now pass a special flag needs_result_type. Closes https://github.com/llvm/circt/issues/1995.
This commit is contained in:
parent
08ed4ff948
commit
3fbedb5ecf
|
@ -18,10 +18,10 @@ with Context() as ctx, Location.unknown():
|
||||||
|
|
||||||
def build(module):
|
def build(module):
|
||||||
# CHECK: %[[CONST:.+]] = hw.constant 1 : i32
|
# CHECK: %[[CONST:.+]] = hw.constant 1 : i32
|
||||||
const = hw.ConstantOp(i32, IntegerAttr.get(i32, 1))
|
const = hw.ConstantOp(IntegerAttr.get(i32, 1))
|
||||||
|
|
||||||
# CHECK: %[[BIT:.+]] = hw.constant true
|
# CHECK: %[[BIT:.+]] = hw.constant true
|
||||||
bit = hw.ConstantOp(i1, IntegerAttr.get(i1, 1))
|
bit = hw.ConstantOp(IntegerAttr.get(i1, 1))
|
||||||
|
|
||||||
# CHECK: comb.extract %[[CONST]] from 14
|
# CHECK: comb.extract %[[CONST]] from 14
|
||||||
comb.ExtractOp.create(14, i32, const.result)
|
comb.ExtractOp.create(14, i32, const.result)
|
||||||
|
@ -30,7 +30,7 @@ with Context() as ctx, Location.unknown():
|
||||||
connect(extract.input, const.result)
|
connect(extract.input, const.result)
|
||||||
|
|
||||||
# CHECK: comb.parity %[[CONST]]
|
# CHECK: comb.parity %[[CONST]]
|
||||||
comb.ParityOp.create(const.result, result_type=i32)
|
comb.ParityOp.create(const.result)
|
||||||
# CHECK: comb.parity %[[CONST]]
|
# CHECK: comb.parity %[[CONST]]
|
||||||
parity = comb.ParityOp.create(result_type=i32)
|
parity = comb.ParityOp.create(result_type=i32)
|
||||||
connect(parity.input, const.result)
|
connect(parity.input, const.result)
|
||||||
|
@ -173,10 +173,10 @@ with Context() as ctx, Location.unknown():
|
||||||
comb.XorOp.create(const.result, const.result)
|
comb.XorOp.create(const.result, const.result)
|
||||||
|
|
||||||
# CHECK: comb.concat %[[CONST]], %[[CONST]]
|
# CHECK: comb.concat %[[CONST]], %[[CONST]]
|
||||||
comb.ConcatOp.create(i32, const.result, const.result)
|
comb.ConcatOp.create(const.result, const.result)
|
||||||
|
|
||||||
# CHECK: comb.mux %[[BIT]], %[[CONST]], %[[CONST]]
|
# CHECK: comb.mux %[[BIT]], %[[CONST]], %[[CONST]]
|
||||||
comb.MuxOp.create(i32, bit.result, const.result, const.result)
|
comb.MuxOp.create(bit.result, const.result, const.result)
|
||||||
|
|
||||||
hw.HWModuleOp(name="test", body_builder=build)
|
hw.HWModuleOp(name="test", body_builder=build)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# REQUIRES: bindings_python
|
# REQUIRES: bindings_python
|
||||||
# RUN: %PYTHON% %s | FileCheck %s
|
# RUN: %PYTHON% %s 2>&1 | FileCheck %s
|
||||||
|
|
||||||
import circt
|
import circt
|
||||||
from circt.dialects import comb, hw
|
from circt.dialects import comb, hw
|
||||||
|
@ -16,19 +16,17 @@ with Context() as ctx, Location.unknown():
|
||||||
with InsertionPoint(m.body):
|
with InsertionPoint(m.body):
|
||||||
|
|
||||||
def build(module):
|
def build(module):
|
||||||
const1 = hw.ConstantOp(i32, IntegerAttr.get(i32, 1))
|
const1 = hw.ConstantOp(IntegerAttr.get(i32, 1))
|
||||||
const2 = hw.ConstantOp(i31, IntegerAttr.get(i31, 1))
|
const2 = hw.ConstantOp(IntegerAttr.get(i31, 1))
|
||||||
|
|
||||||
# CHECK: expected same input port types, but received [Type(i32), Type(i31)]
|
# CHECK: op requires all operands to have the same type
|
||||||
try:
|
div = comb.DivSOp.create(const1.result, const2.result)
|
||||||
comb.DivSOp.create(const1.result, const2.result)
|
div.opview.verify()
|
||||||
except TypeError as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
# CHECK: result type must be specified
|
# CHECK: result type cannot be None
|
||||||
try:
|
try:
|
||||||
comb.DivSOp.create()
|
comb.DivSOp.create()
|
||||||
except TypeError as e:
|
except ValueError as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
hw.HWModuleOp(name="test", body_builder=build)
|
hw.HWModuleOp(name="test", body_builder=build)
|
||||||
|
|
|
@ -17,7 +17,7 @@ with Context() as ctx, Location.unknown():
|
||||||
with InsertionPoint(m.body):
|
with InsertionPoint(m.body):
|
||||||
|
|
||||||
def build(module):
|
def build(module):
|
||||||
constI32 = hw.ConstantOp(i32, IntegerAttr.get(i32, 1))
|
constI32 = hw.ConstantOp(IntegerAttr.get(i32, 1))
|
||||||
constI1 = hw.ConstantOp.create(i1, 1)
|
constI1 = hw.ConstantOp.create(i1, 1)
|
||||||
|
|
||||||
# CHECK: All arguments must be the same type to create an array
|
# CHECK: All arguments must be the same type to create an array
|
||||||
|
|
|
@ -4,16 +4,6 @@ from circt.support import NamedValueOpView, get_value
|
||||||
from mlir.ir import IntegerAttr, IntegerType, Type
|
from mlir.ir import IntegerAttr, IntegerType, Type
|
||||||
|
|
||||||
|
|
||||||
def infer_result_type(operands):
|
|
||||||
types = list(map(lambda arg: get_value(arg).type, operands))
|
|
||||||
if not types:
|
|
||||||
raise TypeError("result type must be specified")
|
|
||||||
all_equal = all(type == types[0] for type in types)
|
|
||||||
if not all_equal:
|
|
||||||
raise TypeError(f"expected same input port types, but received {types}")
|
|
||||||
return types[0]
|
|
||||||
|
|
||||||
|
|
||||||
# Builder base classes for non-variadic unary and binary ops.
|
# Builder base classes for non-variadic unary and binary ops.
|
||||||
class UnaryOpBuilder(NamedValueOpView):
|
class UnaryOpBuilder(NamedValueOpView):
|
||||||
|
|
||||||
|
@ -59,10 +49,6 @@ def BinaryOp(base):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, lhs=None, rhs=None, result_type=None):
|
def create(cls, lhs=None, rhs=None, result_type=None):
|
||||||
if not result_type:
|
|
||||||
if not lhs and not rhs:
|
|
||||||
raise TypeError("result type must be specified")
|
|
||||||
result_type = infer_result_type([lhs, rhs])
|
|
||||||
mapping = {}
|
mapping = {}
|
||||||
if lhs:
|
if lhs:
|
||||||
mapping["lhs"] = lhs
|
mapping["lhs"] = lhs
|
||||||
|
@ -80,8 +66,7 @@ def VariadicOp(base):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, *args):
|
def create(cls, *args):
|
||||||
result_type = infer_result_type(args)
|
return cls([get_value(a) for a in args])
|
||||||
return cls(result_type, [get_value(a) for a in args])
|
|
||||||
|
|
||||||
return _Class
|
return _Class
|
||||||
|
|
||||||
|
@ -105,7 +90,10 @@ class ExtractOp:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(low_bit, result_type, input=None):
|
def create(low_bit, result_type, input=None):
|
||||||
mapping = {"input": input} if input else {}
|
mapping = {"input": input} if input else {}
|
||||||
return ExtractOpBuilder(low_bit, result_type, mapping)
|
return ExtractOpBuilder(low_bit,
|
||||||
|
result_type,
|
||||||
|
mapping,
|
||||||
|
needs_result_type=True)
|
||||||
|
|
||||||
|
|
||||||
@UnaryOp
|
@UnaryOp
|
||||||
|
@ -113,9 +101,12 @@ class ParityOp:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@UnaryOp
|
|
||||||
class SExtOp:
|
class SExtOp:
|
||||||
pass
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, input=None, result_type=None):
|
||||||
|
mapping = {"input": input} if input else {}
|
||||||
|
return UnaryOpBuilder(cls, result_type, mapping, needs_result_type=True)
|
||||||
|
|
||||||
|
|
||||||
# Sugar classes for the various non-variadic binary ops.
|
# Sugar classes for the various non-variadic binary ops.
|
||||||
|
@ -185,14 +176,12 @@ class XorOp:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Sugar classes for miscellaneous ops.
|
@VariadicOp
|
||||||
class ConcatOp:
|
class ConcatOp:
|
||||||
|
pass
|
||||||
@classmethod
|
|
||||||
def create(cls, result_type, *args, **kwargs):
|
|
||||||
return cls(result_type, args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
|
# Sugar classes for miscellaneous ops.
|
||||||
@CreatableOp
|
@CreatableOp
|
||||||
class MuxOp:
|
class MuxOp:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -78,6 +78,7 @@ class InstanceBuilder(support.NamedValueOpView):
|
||||||
input_port_mapping,
|
input_port_mapping,
|
||||||
pre_args,
|
pre_args,
|
||||||
post_args,
|
post_args,
|
||||||
|
needs_result_type=True,
|
||||||
loc=loc,
|
loc=loc,
|
||||||
ip=ip)
|
ip=ip)
|
||||||
|
|
||||||
|
@ -325,7 +326,7 @@ class ConstantOp:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(data_type, value):
|
def create(data_type, value):
|
||||||
return hw.ConstantOp(data_type, IntegerAttr.get(data_type, value))
|
return hw.ConstantOp(IntegerAttr.get(data_type, value))
|
||||||
|
|
||||||
|
|
||||||
class ArrayGetOp:
|
class ArrayGetOp:
|
||||||
|
|
|
@ -69,4 +69,5 @@ class CompRegOp:
|
||||||
kwargs,
|
kwargs,
|
||||||
reset=reset,
|
reset=reset,
|
||||||
reset_value=reset_value,
|
reset_value=reset_value,
|
||||||
name=name)
|
name=name,
|
||||||
|
needs_result_type=True)
|
||||||
|
|
|
@ -32,12 +32,15 @@ def CompareOp(predicate):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(lhs=None, rhs=None):
|
def create(lhs=None, rhs=None):
|
||||||
result_type = IntegerType.get_signless(1)
|
|
||||||
mapping = {}
|
mapping = {}
|
||||||
if lhs:
|
if lhs:
|
||||||
mapping["lhs"] = lhs
|
mapping["lhs"] = lhs
|
||||||
if rhs:
|
if rhs:
|
||||||
mapping["rhs"] = rhs
|
mapping["rhs"] = rhs
|
||||||
|
if len(mapping) == 0:
|
||||||
|
result_type = IntegerType.get_signless(1)
|
||||||
|
else:
|
||||||
|
result_type = None
|
||||||
return ICmpOpBuilder(predicate, result_type, mapping)
|
return ICmpOpBuilder(predicate, result_type, mapping)
|
||||||
|
|
||||||
return _Class
|
return _Class
|
||||||
|
|
|
@ -19,7 +19,7 @@ def reg(value, clock, reset=None, reset_value=None, name=None):
|
||||||
if reset:
|
if reset:
|
||||||
if not reset_value:
|
if not reset_value:
|
||||||
zero = IntegerAttr.get(value_type, 0)
|
zero = IntegerAttr.get(value_type, 0)
|
||||||
reset_value = hw.ConstantOp(value_type, zero).result
|
reset_value = hw.ConstantOp(zero).result
|
||||||
return CompRegOp.create(value_type,
|
return CompRegOp.create(value_type,
|
||||||
input=value,
|
input=value,
|
||||||
clk=clock,
|
clk=clock,
|
||||||
|
|
|
@ -269,11 +269,20 @@ class NamedValueOpView:
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
cls,
|
cls,
|
||||||
data_type,
|
data_type=None,
|
||||||
input_port_mapping={},
|
input_port_mapping=None,
|
||||||
pre_args=[],
|
pre_args=None,
|
||||||
post_args=[],
|
post_args=None,
|
||||||
|
needs_result_type=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
# Set defaults
|
||||||
|
if input_port_mapping is None:
|
||||||
|
input_port_mapping = {}
|
||||||
|
if pre_args is None:
|
||||||
|
pre_args = []
|
||||||
|
if post_args is None:
|
||||||
|
post_args = []
|
||||||
|
|
||||||
# Set result_indices to name each result.
|
# Set result_indices to name each result.
|
||||||
result_names = self.result_names()
|
result_names = self.result_names()
|
||||||
result_indices = {}
|
result_indices = {}
|
||||||
|
@ -302,8 +311,16 @@ class NamedValueOpView:
|
||||||
if isinstance(data_type, list):
|
if isinstance(data_type, list):
|
||||||
operand_values = [operand_values]
|
operand_values = [operand_values]
|
||||||
|
|
||||||
self.opview = cls(data_type, *pre_args, *operand_values, *post_args,
|
# In many cases, result types are inferred, and we do not need to pass
|
||||||
**kwargs)
|
# data_type to the underlying constructor. It must be provided to
|
||||||
|
# NamedValueOpView in cases where we need to build backedges, but should
|
||||||
|
# generally not be passed to the underlying constructor in this case. There
|
||||||
|
# are some oddball ops that must pass it, even when building backedges, and
|
||||||
|
# these set needs_result_type=True.
|
||||||
|
if data_type is not None and (needs_result_type or len(backedges) == 0):
|
||||||
|
pre_args.insert(0, data_type)
|
||||||
|
|
||||||
|
self.opview = cls(*pre_args, *operand_values, *post_args, **kwargs)
|
||||||
self.operand_indices = operand_indices
|
self.operand_indices = operand_indices
|
||||||
self.result_indices = result_indices
|
self.result_indices = result_indices
|
||||||
self.backedges = backedges
|
self.backedges = backedges
|
||||||
|
|
Loading…
Reference in New Issue