diff --git a/integration_test/Bindings/Python/dialects/om.py b/integration_test/Bindings/Python/dialects/om.py index fd6dcf0030..e7afceb1dd 100644 --- a/integration_test/Bindings/Python/dialects/om.py +++ b/integration_test/Bindings/Python/dialects/om.py @@ -68,6 +68,11 @@ with Context() as ctx, Location.unknown(): %map = om.map_create %entry1, %entry2: !om.string, !om.integer om.class.field @map_create, %map : !om.map + + %true = om.constant true + om.class.field @true, %true : i1 + %false = om.constant false + om.class.field @false, %false : i1 } om.class @Child(%0: !om.integer) { @@ -157,7 +162,7 @@ print(obj.get_field_loc("field")) # CHECK: 14 print(obj.child.foo) -# CHECK: loc("-":60:7) +# CHECK: loc("-":65:7) print(obj.child.get_field_loc("foo")) # CHECK: ('Root', 'x') print(obj.reference) @@ -224,6 +229,11 @@ for k, v in obj.map_create.items(): # CHECK-NEXT: Y 15 print(k, v) +# CHECK: True +print(obj.true) +# CHECK: False +print(obj.false) + obj = evaluator.instantiate("Client") object_dict: Dict[om.Object, str] = {} for field_name, data in obj: diff --git a/lib/Bindings/Python/OMModule.cpp b/lib/Bindings/Python/OMModule.cpp index 2d42186d36..16fa03873e 100644 --- a/lib/Bindings/Python/OMModule.cpp +++ b/lib/Bindings/Python/OMModule.cpp @@ -366,15 +366,6 @@ Map::dunderGetItem(std::variant key) { // Convert a generic MLIR Attribute to a PythonValue. This is basically a C++ // fast path of the parts of attribute_to_var that we use in the OM dialect. static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) { - if (mlirAttributeIsAInteger(attr)) { - MlirType type = mlirAttributeGetType(attr); - if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) - return py::int_(mlirIntegerAttrGetValueInt(attr)); - if (mlirIntegerTypeIsSigned(type)) - return py::int_(mlirIntegerAttrGetValueSInt(attr)); - return py::int_(mlirIntegerAttrGetValueUInt(attr)); - } - if (omAttrIsAIntegerAttr(attr)) { auto strRef = omIntegerAttrToString(attr); return py::int_(py::str(strRef.data, strRef.length)); @@ -389,10 +380,20 @@ static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) { return py::str(strRef.data, strRef.length); } + // BoolAttr's are IntegerAttr's, check this first. if (mlirAttributeIsABool(attr)) { return py::bool_(mlirBoolAttrGetValue(attr)); } + if (mlirAttributeIsAInteger(attr)) { + MlirType type = mlirAttributeGetType(attr); + if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) + return py::int_(mlirIntegerAttrGetValueInt(attr)); + if (mlirIntegerTypeIsSigned(type)) + return py::int_(mlirIntegerAttrGetValueSInt(attr)); + return py::int_(mlirIntegerAttrGetValueUInt(attr)); + } + if (omAttrIsAReferenceAttr(attr)) { auto innerRef = omReferenceAttrGetInnerRef(attr); auto moduleStrRef =