[Py] [PyCDE] Add `get_fields` to struct type and use it in pycde (#1364)

PyCDE uses this for error checking and `.` access.
This commit is contained in:
John Demme 2021-07-05 20:36:54 -07:00 committed by GitHub
parent b57223d21b
commit 2a7505c33f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 6 deletions

View File

@ -24,12 +24,23 @@ class Value:
return Value(hw.ArrayGetOp.create(self.value, idx))
if isinstance(self.type, hw.StructType):
fields = self.type.get_fields()
if sub not in [name for name, _ in fields]:
raise ValueError(f"Struct field '{sub}' not found in {self.type}")
with get_user_loc():
return Value(hw.StructExtractOp.create(self.value, sub))
raise TypeError(
"Subscripting only supported on hw.array and hw.struct types")
def __getattr__(self, attr):
if isinstance(self.type, hw.StructType):
fields = self.type.get_fields()
if attr in [name for name, _ in fields]:
with get_user_loc():
return Value(hw.StructExtractOp.create(self.value, attr))
raise AttributeError(f"'Value' object has no attribute '{attr}'")
# PyCDE needs a custom version of this to support python classes.
def var_to_attribute(obj) -> ir.Attribute:

View File

@ -1,6 +1,15 @@
# RUN: %PYTHON% %s 2>&1 | FileCheck %s
from pycde import dim, types
import sys
# CHECK: [('foo', Type(i1)), ('bar', Type(i13))]
# CHECK: i1
st1 = types.struct({"foo": types.i1, "bar": types.i13})
fields = st1.get_fields()
sys.stderr.write(str(fields) + "\n")
st1.get_field("foo").dump()
print()
# CHECK: i6
array1 = dim(types.i6)

View File

@ -28,6 +28,7 @@ struct HWStructFieldInfo {
MlirStringRef name;
MlirType type;
};
typedef struct HWStructFieldInfo HWStructFieldInfo;
//===----------------------------------------------------------------------===//
// Dialect API.
@ -80,13 +81,17 @@ MLIR_CAPI_EXPORTED MlirType hwInOutTypeGet(MlirType element);
MLIR_CAPI_EXPORTED MlirType hwInOutTypeGetElementType(MlirType);
/// Creates an HW struct type in the context associated with the elements.
MLIR_CAPI_EXPORTED MlirType
hwStructTypeGet(MlirContext ctx, intptr_t numElements,
struct HWStructFieldInfo const *elements);
MLIR_CAPI_EXPORTED MlirType hwStructTypeGet(MlirContext ctx,
intptr_t numElements,
HWStructFieldInfo const *elements);
MLIR_CAPI_EXPORTED MlirType hwStructTypeGetField(MlirType structType,
MlirStringRef fieldName);
MLIR_CAPI_EXPORTED HWStructFieldInfo
hwStructTypeGetFieldNum(MlirType structType, unsigned idx);
MLIR_CAPI_EXPORTED intptr_t hwStructTypeGetNumFields(MlirType structType);
MLIR_CAPI_EXPORTED MlirType hwTypeAliasTypeGet(MlirStringRef scope,
MlirStringRef name,
MlirType innerType);

View File

@ -61,9 +61,20 @@ void circt::python::populateDialectHWSubmodule(py::module &m) {
return cls(hwStructTypeGet(ctx, mlirFieldInfos.size(),
mlirFieldInfos.data()));
})
.def("get_field", [](MlirType self, std::string fieldName) {
return hwStructTypeGetField(
self, mlirStringRefCreateFromCString(fieldName.c_str()));
.def("get_field",
[](MlirType self, std::string fieldName) {
return hwStructTypeGetField(
self, mlirStringRefCreateFromCString(fieldName.c_str()));
})
.def("get_fields", [](MlirType self) {
intptr_t num_fields = hwStructTypeGetNumFields(self);
py::list fields;
for (intptr_t i = 0; i < num_fields; ++i) {
auto field = hwStructTypeGetFieldNum(self, i);
std::string name(field.name.data, field.name.length);
fields.append(py::make_tuple(name, field.type));
}
return fields;
});
mlir_type_subclass(m, "TypeAliasType", hwTypeIsATypeAliasType)

View File

@ -79,6 +79,20 @@ MlirType hwStructTypeGetField(MlirType structType, MlirStringRef fieldName) {
return wrap(st.getFieldType(unwrap(fieldName)));
}
intptr_t hwStructTypeGetNumFields(MlirType structType) {
StructType st = unwrap(structType).cast<StructType>();
return st.getElements().size();
}
HWStructFieldInfo hwStructTypeGetFieldNum(MlirType structType, unsigned idx) {
StructType st = unwrap(structType).cast<StructType>();
auto cppField = st.getElements()[idx];
HWStructFieldInfo ret;
ret.name = wrap(cppField.name);
ret.type = wrap(cppField.type);
return ret;
}
bool hwTypeIsATypeAliasType(MlirType type) {
return unwrap(type).isa<TypeAliasType>();
}