From 3fbedb5ecf4efae3ebdf4e7c9df0e23bca3515b5 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Thu, 4 Nov 2021 16:13:58 -0600 Subject: [PATCH] [Python] Update bindings to work with new result type inference. Python bindings now attempt to perform result type inference, so we no longer have to. This is great in general, but we need to complicate our logic to support the NamedValueOpView functionality. In cases where all operands are known a-priori, we defer to ODS-generated result type inference when possible. In these cases, the result type is removed from user-facing APIs, and may require downstream updates to remove the now-redundant result type at call sites. In cases where not all operands are known and we are building backedges, we take care to ensure the ODS-generated result type inference is used. Ops that may build backedges and do not have result type inference now pass a special flag needs_result_type. Closes https://github.com/llvm/circt/issues/1995. --- .../Bindings/Python/dialects/comb.py | 10 ++--- .../Bindings/Python/dialects/comb_errors.py | 18 ++++----- .../Bindings/Python/dialects/hw.py | 2 +- .../Python/circt/dialects/_comb_ops_ext.py | 37 +++++++------------ .../Python/circt/dialects/_hw_ops_ext.py | 3 +- .../Python/circt/dialects/_seq_ops_ext.py | 3 +- lib/Bindings/Python/circt/dialects/comb.py | 5 ++- lib/Bindings/Python/circt/dialects/seq.py | 2 +- lib/Bindings/Python/circt/support.py | 29 ++++++++++++--- 9 files changed, 59 insertions(+), 50 deletions(-) diff --git a/integration_test/Bindings/Python/dialects/comb.py b/integration_test/Bindings/Python/dialects/comb.py index f536217c2a..c1313fc9e8 100644 --- a/integration_test/Bindings/Python/dialects/comb.py +++ b/integration_test/Bindings/Python/dialects/comb.py @@ -18,10 +18,10 @@ with Context() as ctx, Location.unknown(): def build(module): # CHECK: %[[CONST:.+]] = hw.constant 1 : i32 - const = hw.ConstantOp(i32, IntegerAttr.get(i32, 1)) + const = hw.ConstantOp(IntegerAttr.get(i32, 1)) # CHECK: %[[BIT:.+]] = hw.constant true - bit = hw.ConstantOp(i1, IntegerAttr.get(i1, 1)) + bit = hw.ConstantOp(IntegerAttr.get(i1, 1)) # CHECK: comb.extract %[[CONST]] from 14 comb.ExtractOp.create(14, i32, const.result) @@ -30,7 +30,7 @@ with Context() as ctx, Location.unknown(): connect(extract.input, const.result) # CHECK: comb.parity %[[CONST]] - comb.ParityOp.create(const.result, result_type=i32) + comb.ParityOp.create(const.result) # CHECK: comb.parity %[[CONST]] parity = comb.ParityOp.create(result_type=i32) connect(parity.input, const.result) @@ -173,10 +173,10 @@ with Context() as ctx, Location.unknown(): comb.XorOp.create(const.result, const.result) # CHECK: comb.concat %[[CONST]], %[[CONST]] - comb.ConcatOp.create(i32, const.result, const.result) + comb.ConcatOp.create(const.result, const.result) # CHECK: comb.mux %[[BIT]], %[[CONST]], %[[CONST]] - comb.MuxOp.create(i32, bit.result, const.result, const.result) + comb.MuxOp.create(bit.result, const.result, const.result) hw.HWModuleOp(name="test", body_builder=build) diff --git a/integration_test/Bindings/Python/dialects/comb_errors.py b/integration_test/Bindings/Python/dialects/comb_errors.py index 35f9cfbd38..fc43359cff 100644 --- a/integration_test/Bindings/Python/dialects/comb_errors.py +++ b/integration_test/Bindings/Python/dialects/comb_errors.py @@ -1,5 +1,5 @@ # REQUIRES: bindings_python -# RUN: %PYTHON% %s | FileCheck %s +# RUN: %PYTHON% %s 2>&1 | FileCheck %s import circt from circt.dialects import comb, hw @@ -16,19 +16,17 @@ with Context() as ctx, Location.unknown(): with InsertionPoint(m.body): def build(module): - const1 = hw.ConstantOp(i32, IntegerAttr.get(i32, 1)) - const2 = hw.ConstantOp(i31, IntegerAttr.get(i31, 1)) + const1 = hw.ConstantOp(IntegerAttr.get(i32, 1)) + const2 = hw.ConstantOp(IntegerAttr.get(i31, 1)) - # CHECK: expected same input port types, but received [Type(i32), Type(i31)] - try: - comb.DivSOp.create(const1.result, const2.result) - except TypeError as e: - print(e) + # CHECK: op requires all operands to have the same type + div = comb.DivSOp.create(const1.result, const2.result) + div.opview.verify() - # CHECK: result type must be specified + # CHECK: result type cannot be None try: comb.DivSOp.create() - except TypeError as e: + except ValueError as e: print(e) hw.HWModuleOp(name="test", body_builder=build) diff --git a/integration_test/Bindings/Python/dialects/hw.py b/integration_test/Bindings/Python/dialects/hw.py index 0c622dde30..a88ce3b7d3 100644 --- a/integration_test/Bindings/Python/dialects/hw.py +++ b/integration_test/Bindings/Python/dialects/hw.py @@ -17,7 +17,7 @@ with Context() as ctx, Location.unknown(): with InsertionPoint(m.body): def build(module): - constI32 = hw.ConstantOp(i32, IntegerAttr.get(i32, 1)) + constI32 = hw.ConstantOp(IntegerAttr.get(i32, 1)) constI1 = hw.ConstantOp.create(i1, 1) # CHECK: All arguments must be the same type to create an array diff --git a/lib/Bindings/Python/circt/dialects/_comb_ops_ext.py b/lib/Bindings/Python/circt/dialects/_comb_ops_ext.py index a7042fad75..b3431bd92e 100644 --- a/lib/Bindings/Python/circt/dialects/_comb_ops_ext.py +++ b/lib/Bindings/Python/circt/dialects/_comb_ops_ext.py @@ -4,16 +4,6 @@ from circt.support import NamedValueOpView, get_value from mlir.ir import IntegerAttr, IntegerType, Type -def infer_result_type(operands): - types = list(map(lambda arg: get_value(arg).type, operands)) - if not types: - raise TypeError("result type must be specified") - all_equal = all(type == types[0] for type in types) - if not all_equal: - raise TypeError(f"expected same input port types, but received {types}") - return types[0] - - # Builder base classes for non-variadic unary and binary ops. class UnaryOpBuilder(NamedValueOpView): @@ -59,10 +49,6 @@ def BinaryOp(base): @classmethod def create(cls, lhs=None, rhs=None, result_type=None): - if not result_type: - if not lhs and not rhs: - raise TypeError("result type must be specified") - result_type = infer_result_type([lhs, rhs]) mapping = {} if lhs: mapping["lhs"] = lhs @@ -80,8 +66,7 @@ def VariadicOp(base): @classmethod def create(cls, *args): - result_type = infer_result_type(args) - return cls(result_type, [get_value(a) for a in args]) + return cls([get_value(a) for a in args]) return _Class @@ -105,7 +90,10 @@ class ExtractOp: @staticmethod def create(low_bit, result_type, input=None): mapping = {"input": input} if input else {} - return ExtractOpBuilder(low_bit, result_type, mapping) + return ExtractOpBuilder(low_bit, + result_type, + mapping, + needs_result_type=True) @UnaryOp @@ -113,9 +101,12 @@ class ParityOp: pass -@UnaryOp class SExtOp: - pass + + @classmethod + def create(cls, input=None, result_type=None): + mapping = {"input": input} if input else {} + return UnaryOpBuilder(cls, result_type, mapping, needs_result_type=True) # Sugar classes for the various non-variadic binary ops. @@ -185,14 +176,12 @@ class XorOp: pass -# Sugar classes for miscellaneous ops. +@VariadicOp class ConcatOp: - - @classmethod - def create(cls, result_type, *args, **kwargs): - return cls(result_type, args, **kwargs) + pass +# Sugar classes for miscellaneous ops. @CreatableOp class MuxOp: pass diff --git a/lib/Bindings/Python/circt/dialects/_hw_ops_ext.py b/lib/Bindings/Python/circt/dialects/_hw_ops_ext.py index 3891e71578..83544f1834 100644 --- a/lib/Bindings/Python/circt/dialects/_hw_ops_ext.py +++ b/lib/Bindings/Python/circt/dialects/_hw_ops_ext.py @@ -78,6 +78,7 @@ class InstanceBuilder(support.NamedValueOpView): input_port_mapping, pre_args, post_args, + needs_result_type=True, loc=loc, ip=ip) @@ -325,7 +326,7 @@ class ConstantOp: @staticmethod def create(data_type, value): - return hw.ConstantOp(data_type, IntegerAttr.get(data_type, value)) + return hw.ConstantOp(IntegerAttr.get(data_type, value)) class ArrayGetOp: diff --git a/lib/Bindings/Python/circt/dialects/_seq_ops_ext.py b/lib/Bindings/Python/circt/dialects/_seq_ops_ext.py index 37cc6a011b..7718033416 100644 --- a/lib/Bindings/Python/circt/dialects/_seq_ops_ext.py +++ b/lib/Bindings/Python/circt/dialects/_seq_ops_ext.py @@ -69,4 +69,5 @@ class CompRegOp: kwargs, reset=reset, reset_value=reset_value, - name=name) + name=name, + needs_result_type=True) diff --git a/lib/Bindings/Python/circt/dialects/comb.py b/lib/Bindings/Python/circt/dialects/comb.py index d8c3db191c..48f860c520 100644 --- a/lib/Bindings/Python/circt/dialects/comb.py +++ b/lib/Bindings/Python/circt/dialects/comb.py @@ -32,12 +32,15 @@ def CompareOp(predicate): @staticmethod def create(lhs=None, rhs=None): - result_type = IntegerType.get_signless(1) mapping = {} if lhs: mapping["lhs"] = lhs if rhs: mapping["rhs"] = rhs + if len(mapping) == 0: + result_type = IntegerType.get_signless(1) + else: + result_type = None return ICmpOpBuilder(predicate, result_type, mapping) return _Class diff --git a/lib/Bindings/Python/circt/dialects/seq.py b/lib/Bindings/Python/circt/dialects/seq.py index df8dcb45b3..efd5305587 100644 --- a/lib/Bindings/Python/circt/dialects/seq.py +++ b/lib/Bindings/Python/circt/dialects/seq.py @@ -19,7 +19,7 @@ def reg(value, clock, reset=None, reset_value=None, name=None): if reset: if not reset_value: zero = IntegerAttr.get(value_type, 0) - reset_value = hw.ConstantOp(value_type, zero).result + reset_value = hw.ConstantOp(zero).result return CompRegOp.create(value_type, input=value, clk=clock, diff --git a/lib/Bindings/Python/circt/support.py b/lib/Bindings/Python/circt/support.py index 3951d3fed8..8f97326aa6 100644 --- a/lib/Bindings/Python/circt/support.py +++ b/lib/Bindings/Python/circt/support.py @@ -269,11 +269,20 @@ class NamedValueOpView: def __init__(self, cls, - data_type, - input_port_mapping={}, - pre_args=[], - post_args=[], + data_type=None, + input_port_mapping=None, + pre_args=None, + post_args=None, + needs_result_type=False, **kwargs): + # Set defaults + if input_port_mapping is None: + input_port_mapping = {} + if pre_args is None: + pre_args = [] + if post_args is None: + post_args = [] + # Set result_indices to name each result. result_names = self.result_names() result_indices = {} @@ -302,8 +311,16 @@ class NamedValueOpView: if isinstance(data_type, list): operand_values = [operand_values] - self.opview = cls(data_type, *pre_args, *operand_values, *post_args, - **kwargs) + # In many cases, result types are inferred, and we do not need to pass + # data_type to the underlying constructor. It must be provided to + # NamedValueOpView in cases where we need to build backedges, but should + # generally not be passed to the underlying constructor in this case. There + # are some oddball ops that must pass it, even when building backedges, and + # these set needs_result_type=True. + if data_type is not None and (needs_result_type or len(backedges) == 0): + pre_args.insert(0, data_type) + + self.opview = cls(*pre_args, *operand_values, *post_args, **kwargs) self.operand_indices = operand_indices self.result_indices = result_indices self.backedges = backedges