[OM] Add builders for OM ops (#5244)

Add a few convenient builders for ops and attributes.
This commit is contained in:
Prithayan Barua 2023-05-23 10:40:57 -04:00 committed by GitHub
parent 9e3aa6e638
commit 26bb456406
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 79 additions and 31 deletions

View File

@ -43,6 +43,13 @@ def OMSymbolRefAttr : AttrDef<OMDialect, "SymbolRef", [TypedAttrInterface]> {
"mlir::FlatSymbolRefAttr":$ref
);
let builders = [
// Get the SymbolRefAttr to the symbol represented by this operation.
AttrBuilderWithInferredContext<(ins "::mlir::Operation *":$op)>,
// Get the SymbolRefAttr to this symbol name.
AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$symName)>
];
let assemblyFormat = [{
`<` $ref `>`
}];

View File

@ -39,6 +39,12 @@ def ClassOp : OMOp<"class",
SizedRegion<1>:$body
);
let builders = [
OpBuilder<(ins "::mlir::Twine":$name,
"::mlir::ArrayRef<::mlir::StringRef>":$formalParamNames)>,
OpBuilder<(ins "::mlir::Twine":$name)>
];
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
@ -75,6 +81,10 @@ def ObjectOp : OMOp<"object",
ClassType:$result
);
let builders = [
OpBuilder<(ins "om::ClassOp":$classOp, "::mlir::ValueRange":$actualParams)>
];
let assemblyFormat = [{
$className `(` $actualParams `)` `:`
functional-type($actualParams, $result) attr-dict
@ -111,6 +121,10 @@ def ConstantOp : OMOp<"constant",
AnyType:$result
);
let builders = [
OpBuilder<(ins "::mlir::TypedAttr":$constVal)>
];
let assemblyFormat = [{
$value attr-dict
}];

View File

@ -31,6 +31,17 @@ Type circt::om::SymbolRefAttr::getType() {
return SymbolRefType::get(getContext());
}
circt::om::SymbolRefAttr circt::om::SymbolRefAttr::get(mlir::Operation *op) {
return om::SymbolRefAttr::get(op->getContext(),
mlir::FlatSymbolRefAttr::get(op));
}
circt::om::SymbolRefAttr
circt::om::SymbolRefAttr::get(mlir::StringAttr symName) {
return om::SymbolRefAttr::get(symName.getContext(),
mlir::FlatSymbolRefAttr::get(symName));
}
void circt::om::OMDialect::registerAttributes() {
addAttributes<
#define GET_ATTRDEF_LIST

View File

@ -58,6 +58,19 @@ ParseResult circt::om::ClassOp::parse(OpAsmParser &parser,
return success();
}
void circt::om::ClassOp::build(OpBuilder &odsBuilder, OperationState &odsState,
Twine name,
ArrayRef<StringRef> formalParamNames) {
return build(odsBuilder, odsState, odsBuilder.getStringAttr(name),
odsBuilder.getStrArrayAttr(formalParamNames));
}
void circt::om::ClassOp::build(OpBuilder &odsBuilder, OperationState &odsState,
Twine name) {
return build(odsBuilder, odsState, odsBuilder.getStringAttr(name),
odsBuilder.getStrArrayAttr({}));
}
void circt::om::ClassOp::print(OpAsmPrinter &printer) {
// Print the Class symbol name.
printer << " @";
@ -119,6 +132,16 @@ void circt::om::ClassOp::getAsmBlockArgumentNames(
// ObjectOp
//===----------------------------------------------------------------------===//
void circt::om::ObjectOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
om::ClassOp classOp,
::mlir::ValueRange actualParams) {
return build(odsBuilder, odsState,
om::ClassType::get(odsBuilder.getContext(),
mlir::FlatSymbolRefAttr::get(classOp)),
classOp.getNameAttr(), actualParams);
}
LogicalResult
circt::om::ObjectOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Get the containing ModuleOp.
@ -223,6 +246,16 @@ circt::om::ObjectFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
void circt::om::ConstantOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
::mlir::TypedAttr constVal) {
return build(odsBuilder, odsState, constVal.getType(), constVal);
}
//===----------------------------------------------------------------------===//
// TableGen generated logic.
//===----------------------------------------------------------------------===//

View File

@ -62,7 +62,7 @@ TEST(EvaluatorTests, InstantiateInvalidParamSize) {
auto mod = builder.create<ModuleOp>(loc);
builder.setInsertionPointToStart(&mod.getBodyRegion().front());
ArrayAttr params = builder.getStrArrayAttr({"param"});
StringRef params[] = {"param"};
auto cls = builder.create<ClassOp>("MyClass", params);
cls.getBody().emplaceBlock().addArgument(builder.getIntegerType(32),
cls.getLoc());
@ -95,7 +95,7 @@ TEST(EvaluatorTests, InstantiateNullParam) {
auto mod = builder.create<ModuleOp>(loc);
builder.setInsertionPointToStart(&mod.getBodyRegion().front());
ArrayAttr params = builder.getStrArrayAttr({"param"});
StringRef params[] = {"param"};
auto cls = builder.create<ClassOp>("MyClass", params);
cls.getBody().emplaceBlock().addArgument(builder.getIntegerType(32),
cls.getLoc());
@ -126,7 +126,7 @@ TEST(EvaluatorTests, InstantiateInvalidParamType) {
auto mod = builder.create<ModuleOp>(loc);
builder.setInsertionPointToStart(&mod.getBodyRegion().front());
ArrayAttr params = builder.getStrArrayAttr({"param"});
StringRef params[] = {"param"};
auto cls = builder.create<ClassOp>("MyClass", params);
cls.getBody().emplaceBlock().addArgument(builder.getIntegerType(32),
cls.getLoc());
@ -157,7 +157,7 @@ TEST(EvaluatorTests, GetFieldInvalidName) {
auto mod = builder.create<ModuleOp>(loc);
builder.setInsertionPointToStart(&mod.getBodyRegion().front());
auto cls = builder.create<ClassOp>("MyClass", builder.getArrayAttr({}));
auto cls = builder.create<ClassOp>("MyClass");
cls.getBody().emplaceBlock();
Evaluator evaluator(mod);
@ -191,7 +191,7 @@ TEST(EvaluatorTests, InstantiateObjectWithParamField) {
auto mod = builder.create<ModuleOp>(loc);
builder.setInsertionPointToStart(&mod.getBodyRegion().front());
ArrayAttr params = builder.getStrArrayAttr({"param"});
StringRef params[] = {"param"};
auto cls = builder.create<ClassOp>("MyClass", params);
auto &body = cls.getBody().emplaceBlock();
body.addArgument(builder.getIntegerType(32), cls.getLoc());
@ -227,11 +227,10 @@ TEST(EvaluatorTests, InstantiateObjectWithConstantField) {
auto mod = builder.create<ModuleOp>(loc);
builder.setInsertionPointToStart(&mod.getBodyRegion().front());
auto cls = builder.create<ClassOp>("MyClass", builder.getArrayAttr({}));
auto cls = builder.create<ClassOp>("MyClass");
auto &body = cls.getBody().emplaceBlock();
builder.setInsertionPointToStart(&body);
auto constant = builder.create<ConstantOp>(builder.getIntegerType(32),
builder.getI32IntegerAttr(42));
auto constant = builder.create<ConstantOp>(builder.getI32IntegerAttr(42));
builder.create<ClassFieldOp>("field", constant);
Evaluator evaluator(mod);
@ -262,7 +261,7 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObject) {
auto mod = builder.create<ModuleOp>(loc);
builder.setInsertionPointToStart(&mod.getBodyRegion().front());
ArrayAttr params = builder.getStrArrayAttr({"param"});
StringRef params[] = {"param"};
auto innerCls = builder.create<ClassOp>("MyInnerClass", params);
auto &innerBody = innerCls.getBody().emplaceBlock();
innerBody.addArgument(builder.getIntegerType(32), innerCls.getLoc());
@ -274,12 +273,7 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObject) {
auto &body = cls.getBody().emplaceBlock();
body.addArgument(builder.getIntegerType(32), cls.getLoc());
builder.setInsertionPointToStart(&body);
auto innerClsType = ClassType::get(
builder.getContext(),
FlatSymbolRefAttr::get(builder.getStringAttr("MyInnerClass")));
auto innerClsName = builder.getStringAttr("MyInnerClass");
auto object =
builder.create<ObjectOp>(innerClsType, innerClsName, body.getArguments());
auto object = builder.create<ObjectOp>(innerCls, body.getArguments());
builder.create<ClassFieldOp>("field", object);
Evaluator evaluator(mod);
@ -316,7 +310,7 @@ TEST(EvaluatorTests, InstantiateObjectWithFieldAccess) {
auto mod = builder.create<ModuleOp>(loc);
builder.setInsertionPointToStart(&mod.getBodyRegion().front());
ArrayAttr params = builder.getStrArrayAttr({"param"});
StringRef params[] = {"param"};
auto innerCls = builder.create<ClassOp>("MyInnerClass", params);
auto &innerBody = innerCls.getBody().emplaceBlock();
innerBody.addArgument(builder.getIntegerType(32), innerCls.getLoc());
@ -328,12 +322,7 @@ TEST(EvaluatorTests, InstantiateObjectWithFieldAccess) {
auto &body = cls.getBody().emplaceBlock();
body.addArgument(builder.getIntegerType(32), cls.getLoc());
builder.setInsertionPointToStart(&body);
auto innerClsType = ClassType::get(
builder.getContext(),
FlatSymbolRefAttr::get(builder.getStringAttr("MyInnerClass")));
auto innerClsName = builder.getStringAttr("MyInnerClass");
auto object =
builder.create<ObjectOp>(innerClsType, innerClsName, body.getArguments());
auto object = builder.create<ObjectOp>(innerCls, body.getArguments());
auto field =
builder.create<ObjectFieldOp>(builder.getI32Type(), object,
builder.getArrayAttr(FlatSymbolRefAttr::get(
@ -370,20 +359,14 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObjectMemoized) {
auto mod = builder.create<ModuleOp>(loc);
builder.setInsertionPointToStart(&mod.getBodyRegion().front());
auto innerCls =
builder.create<ClassOp>("MyInnerClass", builder.getArrayAttr({}));
auto innerCls = builder.create<ClassOp>("MyInnerClass");
innerCls.getBody().emplaceBlock();
builder.setInsertionPointToStart(&mod.getBodyRegion().front());
auto cls = builder.create<ClassOp>("MyClass", builder.getArrayAttr({}));
auto cls = builder.create<ClassOp>("MyClass");
auto &body = cls.getBody().emplaceBlock();
builder.setInsertionPointToStart(&body);
auto innerClsType = ClassType::get(
builder.getContext(),
FlatSymbolRefAttr::get(builder.getStringAttr("MyInnerClass")));
auto innerClsName = builder.getStringAttr("MyInnerClass");
auto object =
builder.create<ObjectOp>(innerClsType, innerClsName, body.getArguments());
auto object = builder.create<ObjectOp>(innerCls, body.getArguments());
builder.create<ClassFieldOp>("field1", object);
builder.create<ClassFieldOp>("field2", object);