diff --git a/include/circt/Dialect/OM/OMOps.h b/include/circt/Dialect/OM/OMOps.h index 3ca12351e1..ccbcb3bf28 100644 --- a/include/circt/Dialect/OM/OMOps.h +++ b/include/circt/Dialect/OM/OMOps.h @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #define GET_OP_CLASSES #include "circt/Dialect/OM/OM.h.inc" diff --git a/include/circt/Dialect/OM/OMOps.td b/include/circt/Dialect/OM/OMOps.td index 4bee516a55..d33b28c41c 100644 --- a/include/circt/Dialect/OM/OMOps.td +++ b/include/circt/Dialect/OM/OMOps.td @@ -14,8 +14,10 @@ #define CIRCT_DIALECT_OM_OMOPS_TD include "circt/Dialect/OM/OMOpInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/BuiltinTypes.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" @@ -195,4 +197,68 @@ def ListCreateOp : OMOp<"list_create", [Pure, SameTypeOperands]> { let hasCustomAssemblyFormat = 1; } +def TupleCreateOp : OMOp<"tuple_create", [Pure, InferTypeOpInterface]> { + let summary = "Create a tuple of values"; + let description = [{ + Create a tuple from a sequence of inputs. + + ``` + %tuple = om.tuple_create %a, %b, %c : !om.ref, !om.string, !om.list + ``` + }]; + + let arguments = (ins Variadic:$inputs); + let results = (outs + TupleOf<[AnyType]>:$result + ); + + let assemblyFormat = [{ + $inputs `:` type($inputs) attr-dict + }]; + + let extraClassDeclaration = [{ + // Implement InferTypeOpInterface. + static ::mlir::LogicalResult inferReturnTypes( + ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::OpaqueProperties, + ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes); + }]; + +} + +def TupleGetOp : OMOp<"tuple_get", [Pure, InferTypeOpInterface]> { + let summary = "Extract a value from a tuple"; + let description = [{ + Extract a value from a tuple. + + ``` + %value = om.tuple_get %a[0] : tuple> + ``` + }]; + + let arguments = (ins + TupleOf<[AnyType]>:$input, + I32Attr:$index + ); + + let results = (outs + AnyType:$result + ); + + let assemblyFormat = [{ + $input `[` $index `]` `:` type($input) attr-dict + }]; + + let extraClassDeclaration = [{ + // Implement InferTypeOpInterface. + static ::mlir::LogicalResult inferReturnTypes( + ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::OpaqueProperties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes); + }]; +} + #endif // CIRCT_DIALECT_OM_OMOPS_TD diff --git a/lib/Dialect/OM/OMOps.cpp b/lib/Dialect/OM/OMOps.cpp index 1f66db69a9..1c2afc7def 100644 --- a/lib/Dialect/OM/OMOps.cpp +++ b/lib/Dialect/OM/OMOps.cpp @@ -380,6 +380,47 @@ ParseResult circt::om::ListCreateOp::parse(OpAsmParser &parser, return success(); } +//===----------------------------------------------------------------------===// +// TupleCreateOp +//===----------------------------------------------------------------------===// + +LogicalResult TupleCreateOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + ::llvm::SmallVector types; + for (auto op : operands) + types.push_back(op.getType()); + inferredReturnTypes.push_back(TupleType::get(context, types)); + return success(); +} + +//===----------------------------------------------------------------------===// +// TupleGetOp +//===----------------------------------------------------------------------===// + +LogicalResult TupleGetOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + llvm::SmallVectorImpl &inferredReturnTypes) { + auto idx = attributes.getAs("index"); + if (operands.empty() || !idx) + return failure(); + + auto tupleTypes = operands[0].getType().cast().getTypes(); + if (tupleTypes.size() <= idx.getValue().getLimitedValue()) { + if (location) + mlir::emitError(*location, + "tuple index out-of-bounds, must be less than ") + << tupleTypes.size() << " but got " + << idx.getValue().getLimitedValue(); + return failure(); + } + + inferredReturnTypes.push_back(tupleTypes[idx.getValue().getLimitedValue()]); + return success(); +} + //===----------------------------------------------------------------------===// // TableGen generated logic. //===----------------------------------------------------------------------===// diff --git a/test/Dialect/OM/errors.mlir b/test/Dialect/OM/errors.mlir index 80308daaf2..adff8cafb1 100644 --- a/test/Dialect/OM/errors.mlir +++ b/test/Dialect/OM/errors.mlir @@ -102,3 +102,10 @@ om.class @ListCreate() { // expected-error @+1 {{map key type must be either string or integer but got '!om.list'}} om.class @Map(%map: !om.map, !om.string>) { } + +// ----- + +om.class @Tuple(%tuple: tuple) { + // expected-error @+1 {{tuple index out-of-bounds, must be less than 2 but got 2}} + %val = om.tuple_get %tuple[2] : tuple +} diff --git a/test/Dialect/OM/round-trip.mlir b/test/Dialect/OM/round-trip.mlir index b197776366..6e1c081f56 100644 --- a/test/Dialect/OM/round-trip.mlir +++ b/test/Dialect/OM/round-trip.mlir @@ -171,3 +171,15 @@ om.class @StringConstant() { om.class @Map(%map: !om.map) { om.class.field @field, %map : !om.map } + +// CHECK-LABEL: @Tuple +om.class @Tuple(%int: i1, %str: !om.string) { + // CHECK: %[[tuple:.+]] = om.tuple_create %int, %str : i1, !om.string + %tuple = om.tuple_create %int, %str : i1, !om.string + // CHECK-NEXT: om.class.field @tuple, %[[tuple]] : tuple + om.class.field @tuple, %tuple : tuple + // CHECK-NEXT: %[[tuple_get:.+]] = om.tuple_get %[[tuple]][1] : tuple + %val = om.tuple_get %tuple[1] : tuple + // CHECK-NEXT: om.class.field @val, %[[tuple_get]] : !om.string + om.class.field @val, %val : !om.string +}