[MSFT] Add `MSFTModuleOp` (#1801)

Adds a moduleop to the MSFT dialect. For now, it's mostly like hw.module but adds parameters to represent the specific parameterization of a parameterized module. Also necessarily adds a msft.output since hw.output expects to be in a hw.module.

Step 3 of #1755.
This commit is contained in:
John Demme 2021-09-16 12:48:39 -07:00 committed by GitHub
parent 50904ee377
commit 3f738e2017
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 248 additions and 0 deletions

View File

@ -15,9 +15,11 @@
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Pass/PassBase.td"
include "mlir/IR/RegionKindInterface.td"
def MSFTDialect : Dialect {
let name = "msft";

View File

@ -38,3 +38,83 @@ def InstanceOp : MSFTOp<"instance", [
`:` functional-type($inputs, results)
}];
}
def OneOrNoBlocksRegion : Region<
CPred<"::llvm::hasNItemsOrLess($_self, 1)">,
"region with at most 1 block">;
def MSFTModuleOp : MSFTOp<"module",
[IsolatedFromAbove, FunctionLike, Symbol, RegionKindInterface,
HasParent<"mlir::ModuleOp">]>{
let summary = "MSFT HW Module";
let description = [{
A lot like `hw.module`, but with a few differences:
- Can exist without a body. The body is filled in by a generator post op
creation.
- MSFT-specific methods and arguments will be added later on.
}];
let arguments = (ins
StrArrayAttr:$argNames, StrArrayAttr:$resultNames,
DictionaryAttr:$parameters);
let results = (outs);
let regions = (region OneOrNoBlocksRegion:$body);
let extraClassDeclaration = [{
using FunctionLike::front;
using FunctionLike::getBody;
// Implement RegionKindInterface.
static RegionKind getRegionKind(unsigned index) {
return RegionKind::Graph;
}
// Decode information about the input and output ports on this module.
SmallVector<::circt::hw::ModulePortInfo> getPorts();
// Get the module's symbolic name as StringAttr.
StringAttr getNameAttr() {
return (*this)->getAttrOfType<StringAttr>(
::mlir::SymbolTable::getSymbolAttrName());
}
// Get the module's symbolic name.
StringRef getName() {
return getNameAttr().getValue();
}
private:
// This trait needs access to the hooks defined below.
friend class OpTrait::FunctionLike<MSFTModuleOp>;
/// Returns the number of arguments, implementing OpTrait::FunctionLike.
unsigned getNumFuncArguments() { return getType().getInputs().size(); }
/// Returns the number of results, implementing OpTrait::FunctionLike.
unsigned getNumFuncResults() { return getType().getResults().size(); }
/// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
/// attribute is present and checks if it holds a function type. Ensures
/// getType, getNumFuncArguments, and getNumFuncResults can be called
/// safely.
LogicalResult verifyType() {
auto type = getTypeAttr().getValue();
if (!type.isa<FunctionType>())
return emitOpError("requires '" + getTypeAttrName() +
"' attribute of function type");
return success();
}
public:
}];
let printer = "return ::print$cppClass(p, *this);";
let parser = "return ::parse$cppClass(parser, result);";
}
def OutputOp : MSFTOp<"output", [Terminator, HasParent<"MSFTModuleOp">,
NoSideEffect, ReturnLike]> {
let summary = "termination operation";
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}

View File

@ -24,12 +24,38 @@ using namespace msft;
// Dialect specification.
//===----------------------------------------------------------------------===//
namespace {
// We implement the OpAsmDialectInterface so that MSFT dialect operations
// automatically interpret the name attribute on operations as their SSA name.
struct MSFTOpAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
/// Get a special name to use when printing the entry block arguments of the
/// region contained by an operation in this dialect.
void getAsmBlockArgumentNames(Block *block,
OpAsmSetValueNameFn setNameFn) const override {
// Assign port names to the bbargs if this is a module.
auto modOp = dyn_cast<MSFTModuleOp>(block->getParentOp());
if (!modOp)
return;
ArrayAttr argNames = modOp.argNamesAttr();
for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
auto name = argNames[i].cast<StringAttr>().getValue();
if (!name.empty())
setNameFn(block->getArgument(i), name);
}
}
};
} // end anonymous namespace
void MSFTDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "circt/Dialect/MSFT/MSFT.cpp.inc"
>();
registerAttributes();
addInterfaces<MSFTOpAsmDialectInterface>();
}
/// Registered hook to materialize a single constant operation from a given
@ -81,5 +107,6 @@ hw::InstanceOp circt::msft::getInstance(hw::HWModuleOp root,
path.push_back(sym.getValue());
return ::getInstance(root, path);
}
#include "circt/Dialect/MSFT/MSFTDialect.cpp.inc"
#include "circt/Dialect/MSFT/MSFTEnums.cpp.inc"

View File

@ -12,9 +12,11 @@
#include "circt/Dialect/MSFT/MSFTOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/HW/ModuleImplementation.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace circt;
@ -62,5 +64,129 @@ LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
for (auto &argAttr : attrs)
if (argAttr.first == name)
return true;
return false;
}
static ParseResult parseMSFTModuleOp(OpAsmParser &parser,
OperationState &result) {
using namespace mlir::function_like_impl;
auto loc = parser.getCurrentLocation();
SmallVector<OpAsmParser::OperandType, 4> entryArgs;
SmallVector<NamedAttrList, 4> argAttrs;
SmallVector<NamedAttrList, 4> resultAttrs;
SmallVector<Type, 4> argTypes;
SmallVector<Type, 4> resultTypes;
auto &builder = parser.getBuilder();
// Parse the name as a symbol.
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();
// Parse the parameters
DictionaryAttr paramsAttr;
if (parser.parseAttribute(paramsAttr))
return failure();
result.addAttribute("parameters", paramsAttr);
// Parse the function signature.
bool isVariadic = false;
SmallVector<Attribute> resultNames;
if (hw::module_like_impl::parseModuleFunctionSignature(
parser, entryArgs, argTypes, argAttrs, isVariadic, resultTypes,
resultAttrs, resultNames))
return failure();
// Record the argument and result types as an attribute. This is necessary
// for external modules.
auto type = builder.getFunctionType(argTypes, resultTypes);
result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
// If function attributes are present, parse them.
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
auto *context = result.getContext();
if (hasAttribute("argNames", result.attributes) ||
hasAttribute("resultNames", result.attributes)) {
parser.emitError(
loc, "explicit argNames and resultNames attributes not allowed");
return failure();
}
// Use the argument and result names if not already specified.
SmallVector<Attribute> argNames;
if (!entryArgs.empty()) {
for (auto &arg : entryArgs)
argNames.push_back(
hw::module_like_impl::getPortNameAttr(context, arg.name));
} else if (!argTypes.empty()) {
// The parser returns empty names in a special way.
argNames.assign(argTypes.size(), StringAttr::get(context, ""));
}
result.addAttribute("argNames", ArrayAttr::get(context, argNames));
result.addAttribute("resultNames", ArrayAttr::get(context, resultNames));
assert(argAttrs.size() == argTypes.size());
assert(resultAttrs.size() == resultTypes.size());
// Add the attributes to the module arguments.
addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
// Parse the optional module body.
auto regionSuccess = parser.parseOptionalRegion(
*result.addRegion(), entryArgs,
entryArgs.empty() ? ArrayRef<Type>() : argTypes);
if (regionSuccess.hasValue() && failed(*regionSuccess))
return failure();
return success();
}
static void printMSFTModuleOp(OpAsmPrinter &p, MSFTModuleOp mod) {
using namespace mlir::function_like_impl;
FunctionType fnType = mod.getType();
auto argTypes = fnType.getInputs();
auto resultTypes = fnType.getResults();
// Print the operation and the function name.
p << ' ';
p.printSymbolName(SymbolTable::getSymbolName(mod).getValue());
// Print the parameterization.
p << ' ';
p.printAttribute(mod.parametersAttr());
p << ' ';
bool needArgNamesAttr = false;
hw::module_like_impl::printModuleSignature(
p, mod, argTypes, /*isVariadic=*/false, resultTypes, needArgNamesAttr);
SmallVector<StringRef, 3> omittedAttrs;
if (!needArgNamesAttr)
omittedAttrs.push_back("argNames");
omittedAttrs.push_back("resultNames");
omittedAttrs.push_back("parameters");
printFunctionAttributes(p, mod, argTypes.size(), resultTypes.size(),
omittedAttrs);
// Print the body if this is not an external function.
Region &body = mod.getBody();
if (!body.empty())
p.printRegion(body, /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}
#define GET_OP_CLASSES
#include "circt/Dialect/MSFT/MSFT.cpp.inc"

View File

@ -191,6 +191,8 @@ void LowerToHWPass::runOnOperation() {
// Set up a conversion and give it a set of laws.
ConversionTarget target(*ctxt);
target.addIllegalDialect<MSFTDialect>();
target.addLegalOp<MSFTModuleOp>(); // TODO: Remove me!
target.addLegalOp<OutputOp>(); // TODO: Remove me!
target.addLegalDialect<hw::HWDialect>();
// Add all the conversion patterns.

View File

@ -10,3 +10,14 @@ hw.module @top () {
// CHECK: %foo.x = msft.instance "foo" @fooMod() : () -> i32
// HWLOW: %foo.x = hw.instance "foo" @fooMod() -> (x: i32)
}
// CHECK-LABEL: msft.module @B {WIDTH = 1 : i64} (%a: i1) -> (nameOfPortInSV: i1) {
msft.module @B { "WIDTH" = 1 } (%a: i1) -> (nameOfPortInSV: i1) {
%0 = comb.or %a, %a : i1
// CHECK: comb.or %a, %a : i1
%1 = comb.and %a, %a : i1
msft.output %0, %1: i1, i1
}
// CHECK-LABEL: msft.module @UnGenerated {DEPTH = 3 : i64} (%a: i1) -> (nameOfPortInSV: i1)
msft.module @UnGenerated { DEPTH = 3 } (%a: i1) -> (nameOfPortInSV: i1)