[Python][OM] Handle BoolAttr's before IntegerAttr's. (#7438)

BoolAttr's are IntegerAttr's, check them first.

IntegerAttr's that happen to have the characteristics of
BoolAttr will accordingly become Python boolean values.

Unclear where these come from but we do lower booleans
to MLIR bool constants so make sure to handle that.

Add test for object model IR with bool constants.
This commit is contained in:
Will Dietz 2024-08-05 12:21:19 -05:00 committed by GitHub
parent cbdee94d96
commit bec0deab4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 10 deletions

View File

@ -68,6 +68,11 @@ with Context() as ctx, Location.unknown():
%map = om.map_create %entry1, %entry2: !om.string, !om.integer
om.class.field @map_create, %map : !om.map<!om.string, !om.integer>
%true = om.constant true
om.class.field @true, %true : i1
%false = om.constant false
om.class.field @false, %false : i1
}
om.class @Child(%0: !om.integer) {
@ -157,7 +162,7 @@ print(obj.get_field_loc("field"))
# CHECK: 14
print(obj.child.foo)
# CHECK: loc("-":60:7)
# CHECK: loc("-":65:7)
print(obj.child.get_field_loc("foo"))
# CHECK: ('Root', 'x')
print(obj.reference)
@ -224,6 +229,11 @@ for k, v in obj.map_create.items():
# CHECK-NEXT: Y 15
print(k, v)
# CHECK: True
print(obj.true)
# CHECK: False
print(obj.false)
obj = evaluator.instantiate("Client")
object_dict: Dict[om.Object, str] = {}
for field_name, data in obj:

View File

@ -366,15 +366,6 @@ Map::dunderGetItem(std::variant<intptr_t, std::string, MlirAttribute> key) {
// Convert a generic MLIR Attribute to a PythonValue. This is basically a C++
// fast path of the parts of attribute_to_var that we use in the OM dialect.
static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
if (mlirAttributeIsAInteger(attr)) {
MlirType type = mlirAttributeGetType(attr);
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
return py::int_(mlirIntegerAttrGetValueInt(attr));
if (mlirIntegerTypeIsSigned(type))
return py::int_(mlirIntegerAttrGetValueSInt(attr));
return py::int_(mlirIntegerAttrGetValueUInt(attr));
}
if (omAttrIsAIntegerAttr(attr)) {
auto strRef = omIntegerAttrToString(attr);
return py::int_(py::str(strRef.data, strRef.length));
@ -389,10 +380,20 @@ static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
return py::str(strRef.data, strRef.length);
}
// BoolAttr's are IntegerAttr's, check this first.
if (mlirAttributeIsABool(attr)) {
return py::bool_(mlirBoolAttrGetValue(attr));
}
if (mlirAttributeIsAInteger(attr)) {
MlirType type = mlirAttributeGetType(attr);
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
return py::int_(mlirIntegerAttrGetValueInt(attr));
if (mlirIntegerTypeIsSigned(type))
return py::int_(mlirIntegerAttrGetValueSInt(attr));
return py::int_(mlirIntegerAttrGetValueUInt(attr));
}
if (omAttrIsAReferenceAttr(attr)) {
auto innerRef = omReferenceAttrGetInnerRef(attr);
auto moduleStrRef =