[NFC] Use ModuleType to compute many properties (#5715)

In anticipation of most modules converting from functions to portlists using module type, route most of the HWModuleLike interface through the ModuleType.  Modules like things must now be able to produce a moduletype from which many accessors are derived.  getNumPorts is separate from this as non-function-like modules need this to compute the module type.
This commit is contained in:
Andrew Lenharth 2023-08-15 14:41:55 -07:00 committed by GitHub
parent 24f3858255
commit aa4942baaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 129 additions and 97 deletions

View File

@ -82,8 +82,17 @@ class FIRRTLModuleLike<string mnemonic, list<Trait> traits = []> :
}];
let extraClassDefinition = extraModuleClassDefinition # [{
size_t $cppClass::getNumPorts() {
return getPortTypesAttr().size();
circt::hw::ModuleType $cppClass::getHWModuleType() {
SmallVector<hw::ModulePort> newPorts;
auto num = getPortNames().size();
newPorts.reserve(num);
for (unsigned i = 0; i < num; ++i)
newPorts.push_back({getPortNameAttr(i), getPortType(i),
getPortDirection(i) == Direction::In ?
hw::ModulePort::Direction::Input :
hw::ModulePort::Direction::Output});
return hw::ModuleType::get(getContext(), newPorts);
}
circt::hw::InnerSymAttr $cppClass::getPortSymbolAttr(size_t portIndex) {

View File

@ -32,79 +32,60 @@ def HWModuleLike : OpInterface<"HWModuleLike", [
let description = "Provide common module information.";
let methods = [
InterfaceMethod<"Get the number of ports",
"size_t", "getNumPorts">,
InterfaceMethod<"Get a port symbol attribute",
"::circt::hw::InnerSymAttr", "getPortSymbolAttr", (ins "size_t":$portIndex)>,
InterfaceMethod<"Get the module type",
"::circt::hw::ModuleType", "getHWModuleType", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{ return ::circt::hw::detail::fnToMod($_op, getInputNames(), getOutputNames()); }]>,
InterfaceMethod<"Return the names of the inputs to this module",
"mlir::ArrayAttr", "getInputNames", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
return $_op->template getAttrOfType<ArrayAttr>("argNames");
}]>,
InterfaceMethod<"Return the names of the outputs this module",
"mlir::ArrayAttr", "getOutputNames", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
return $_op->template getAttrOfType<ArrayAttr>("resultNames");
}]>,
InterfaceMethod<"Return the locations of the inputs to this module",
"mlir::ArrayAttr", "getInputLocs", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
return $_op->template getAttrOfType<ArrayAttr>("argLocs");
}]>,
InterfaceMethod<"Return the locations of the outputs of this module",
"mlir::ArrayAttr", "getOutputLocs", (ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
return $_op->template getAttrOfType<ArrayAttr>("resultLocs");
}]>,
"::circt::hw::ModuleType", "getHWModuleType", (ins)>,
];
let extraSharedClassDeclaration = [{
// Return the number of inputs to this module
unsigned getNumInputs() {
return getInputNames().size();
/// Return the total number of ports in the module
size_t getNumPorts() {
return $_op.getHWModuleType().getNumPorts();
}
// Return the number of outputs from this module
unsigned getNumOutputs() {
return getOutputNames().size();
/// Return the total number of input and inout ports in the module
size_t getNumInputs() {
return $_op.getHWModuleType().getNumInputs();
}
// Return the name of an input.
mlir::StringRef getInputName(unsigned idx) {
return getInputNameAttr(idx).getValue();
/// Return the total number of output ports in the module
size_t getNumOutputs() {
return $_op.getHWModuleType().getNumOutputs();
}
// Return the name of an output
mlir::StringRef getOutputName(unsigned idx) {
return getOutputNameAttr(idx).getValue();
/// Return the set of names on input and inout ports
SmallVector<StringAttr> getInputNames() {
return $_op.getHWModuleType().getInputNames();
}
// Return the name of an input to this module
mlir::StringAttr getInputNameAttr(unsigned idx) {
return cast<StringAttr>(getInputNames()[idx]);
/// Return the set of names on output ports
SmallVector<StringAttr> getOutputNames() {
return $_op.getHWModuleType().getOutputNames();
}
// Get the name for the specified input or inout port
StringRef getInputName(size_t idx) {
return $_op.getHWModuleType().getInputName(idx);
}
// Return the name of an output to this module
mlir::StringAttr getOutputNameAttr(unsigned idx) {
return cast<StringAttr>(getOutputNames()[idx]);
// Get the name for the specified output port
StringRef getOutputName(size_t idx) {
return $_op.getHWModuleType().getOutputName(idx);
}
StringAttr getInputNameAttr(size_t idx) {
return $_op.getHWModuleType().getInputNameAttr(idx);
}
StringAttr getOutputNameAttr(size_t idx) {
return $_op.getHWModuleType().getOutputNameAttr(idx);
}
}];
let verify = [{

View File

@ -67,9 +67,9 @@ class HWModuleOpBase<string mnemonic, list<Trait> traits = []> :
code extraModuleClassDefinition = [{}];
let extraClassDefinition = extraModuleClassDefinition # [{
size_t $cppClass::getNumPorts() {
// This is wildly inefficient
return getPortList().size();
ModuleType $cppClass::getHWModuleType() {
return detail::fnToMod(getFunctionType(), getArgNames(), getResultNames());
}
::circt::hw::InnerSymAttr $cppClass::getPortSymbolAttr(size_t portIndex) {

View File

@ -241,6 +241,7 @@ def ModuleTypeImpl : HWType<"Module"> {
let extraClassDeclaration = [{
// Many of these are transitional and will be removed when modules and instances
// have moved over to this type.
size_t getNumPorts();
size_t getNumInputs();
size_t getNumOutputs();
SmallVector<Type> getInputTypes();

View File

@ -179,10 +179,9 @@ def MSFTModuleOp : MSFTModuleOpBase<"module",
}];
let extraClassDefinition = [{
size_t $cppClass::getNumPorts() {
// This is excessively expensive
auto ports = getPortList();
return ports.size();
::circt::hw::ModuleType $cppClass::getHWModuleType() {
return ::circt::hw::detail::fnToMod(getFunctionType(), getArgNames(), getResultNames());
}
circt::hw::InnerSymAttr $cppClass::getPortSymbolAttr(size_t portIndex) {
@ -226,6 +225,10 @@ def MSFTModuleExternOp : MSFTOp<"module.extern",
let hasVerifier = 1;
let extraClassDeclaration = [{
/// Decode information about the input and output ports on this module.
hw::ModulePortInfo getPorts();
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
@ -237,6 +240,12 @@ def MSFTModuleExternOp : MSFTOp<"module.extern",
//===------------------------------------------------------------------===//
::circt::hw::ModulePortInfo getPortList();
}];
let extraClassDefinition = [{
::circt::hw::ModuleType $cppClass::getHWModuleType() {
return ::circt::hw::detail::fnToMod(getFunctionType(), getArgNames(), getResultNames());
}
}];
}
def DesignPartitionOp : MSFTOp<"partition",

View File

@ -530,16 +530,18 @@ LogicalResult ESIPureModuleOp::verify() {
return success();
}
hw::ModulePortInfo ESIPureModuleOp::getPortList() {
return hw::ModulePortInfo({});
hw::ModuleType ESIPureModuleOp::getHWModuleType() {
return hw::ModuleType::get(getContext(), {});
}
size_t ESIPureModuleOp::getNumPorts() { return 0; }
hw::InnerSymAttr ESIPureModuleOp::getPortSymbolAttr(size_t portIndex) {
assert(false);
::circt::hw::InnerSymAttr ESIPureModuleOp::getPortSymbolAttr(size_t) {
return {};
}
::circt::hw::ModulePortInfo ESIPureModuleOp::getPortList() {
return hw::ModulePortInfo(ArrayRef<hw::PortInfo>{});
}
#define GET_OP_CLASSES
#include "circt/Dialect/ESI/ESI.cpp.inc"

View File

@ -615,13 +615,19 @@ ESIConnectServicesPass::surfaceReqs(hw::HWMutableModuleLike mod,
// Create a replacement instance of the same operation type.
SmallVector<NamedAttribute> newAttrs;
for (auto attr : inst->getAttrs()) {
if (attr.getName() == argsAttrName)
newAttrs.push_back(b.getNamedAttr(argsAttrName, mod.getInputNames()));
else if (attr.getName() == resultsAttrName)
if (attr.getName() == argsAttrName) {
auto names = mod.getInputNames();
SmallVector<Attribute> nnames(names.begin(), names.end());
newAttrs.push_back(
b.getNamedAttr(resultsAttrName, mod.getOutputNames()));
else
b.getNamedAttr(argsAttrName, b.getArrayAttr(nnames)));
} else if (attr.getName() == resultsAttrName) {
auto names = mod.getOutputNames();
SmallVector<Attribute> nnames(names.begin(), names.end());
newAttrs.push_back(
b.getNamedAttr(resultsAttrName, b.getArrayAttr(nnames)));
} else {
newAttrs.push_back(attr);
}
}
auto *newInst = b.insert(Operation::create(
inst->getLoc(), inst->getName(), newResultTypes, newOperands,

View File

@ -110,7 +110,7 @@ static bool applyToPort(AnnotationSet annos, Operation *op, size_t portCount,
}
bool AnnotationSet::applyToPort(FModuleLike op, size_t portNo) const {
return ::applyToPort(*this, op.getOperation(), getNumPorts(op), portNo);
return ::applyToPort(*this, op.getOperation(), op.getNumPorts(), portNo);
}
bool AnnotationSet::applyToPort(MemOp op, size_t portNo) const {

View File

@ -221,7 +221,7 @@ DeclKind firrtl::getDeclarationKind(Value val) {
}
size_t firrtl::getNumPorts(Operation *op) {
if (auto module = dyn_cast<hw::HWModuleLike>(op))
if (auto module = dyn_cast<FModuleLike>(op))
return module.getNumPorts();
return op->getNumResults();
}
@ -458,7 +458,7 @@ Block *CircuitOp::getBodyBlock() { return &getBody().front(); }
static SmallVector<PortInfo> getPortImpl(FModuleLike module) {
SmallVector<PortInfo> results;
for (unsigned i = 0, e = getNumPorts(module); i < e; ++i) {
for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i) {
results.push_back({module.getPortNameAttr(i), module.getPortType(i),
module.getPortDirection(i), module.getPortSymbolAttr(i),
module.getPortLocation(i),
@ -523,7 +523,7 @@ static void insertPorts(FModuleLike op,
ArrayRef<std::pair<unsigned, PortInfo>> ports) {
if (ports.empty())
return;
unsigned oldNumArgs = getNumPorts(op);
unsigned oldNumArgs = op.getNumPorts();
unsigned newNumArgs = oldNumArgs + ports.size();
// Add direction markers and names for new ports.
@ -604,7 +604,7 @@ static void erasePorts(FModuleLike op, const llvm::BitVector &portIndices) {
ArrayRef<Attribute> portAnnos = op.getPortAnnotations();
ArrayRef<Attribute> portSyms = op.getPortSymbols();
ArrayRef<Attribute> portLocs = op.getPortLocations();
auto numPorts = getNumPorts(op);
auto numPorts = op.getNumPorts();
(void)numPorts;
assert(portDirections.size() == numPorts);
assert(portNames.size() == numPorts);
@ -1710,7 +1710,7 @@ void InstanceOp::build(OpBuilder &builder, OperationState &result,
// Gather the result types.
SmallVector<Type> resultTypes;
resultTypes.reserve(getNumPorts(module));
resultTypes.reserve(module.getNumPorts());
llvm::transform(
module.getPortTypes(), std::back_inserter(resultTypes),
[](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
@ -1860,7 +1860,7 @@ LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that all the attribute arrays are the right length up front. This
// lets us safely use the port name in error messages below.
size_t numResults = getNumResults();
size_t numExpected = getNumPorts(referencedModule);
size_t numExpected = referencedModule.getNumPorts();
if (numResults != numExpected) {
return emitNote(emitOpError() << "has a wrong number of results; expected "
<< numExpected << " but got " << numResults);

View File

@ -867,7 +867,7 @@ void Inliner::mapPortsToWires(StringRef prefix, InliningLevel &il,
const DenseSet<Attribute> &localSymbols) {
auto target = il.childModule;
auto portInfo = target.getPorts();
for (unsigned i = 0, e = getNumPorts(target); i < e; ++i) {
for (unsigned i = 0, e = target.getNumPorts(); i < e; ++i) {
auto arg = target.getArgument(i);
// Get the type of the wire.
auto type = type_cast<FIRRTLType>(arg.getType());
@ -1331,7 +1331,7 @@ void Inliner::identifyNLAsTargetingOnlyModules() {
referencedNLASyms.insert(sym.getAttr());
};
// Scan ports
for (unsigned i = 0, e = getNumPorts(mod); i != e; ++i)
for (unsigned i = 0, e = mod.getNumPorts(); i != e; ++i)
scanAnnos(AnnotationSet::forPort(mod, i));
// Scan operations (and not the module itself):

View File

@ -76,6 +76,8 @@ void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
for (unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) {
hw::PortInfo port;
port.name = getArgName(i);
if (!port.name)
port.name = StringAttr::get(getContext(), "in" + std::to_string(i));
port.dir = circt::hw::ModulePort::Direction::Input;
port.type = machineType.getInput(i);
port.argNum = i;
@ -85,6 +87,8 @@ void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
for (unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) {
hw::PortInfo port;
port.name = getResName(i);
if (!port.name)
port.name = StringAttr::get(getContext(), "out" + std::to_string(i));
port.dir = circt::hw::ModulePort::Direction::Output;
port.type = machineType.getResult(i);
port.argNum = i;
@ -94,9 +98,9 @@ void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
ParseResult MachineOp::parse(OpAsmParser &parser, OperationState &result) {
auto buildFuncType =
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
[&](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,

View File

@ -603,9 +603,19 @@ LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
size_t ModuleType::getNumInputs() { return getInputTypes().size(); }
size_t ModuleType::getNumInputs() {
return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) {
return p.dir != ModulePort::Direction::Output;
});
}
size_t ModuleType::getNumOutputs() { return getOutputTypes().size(); }
size_t ModuleType::getNumOutputs() {
return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) {
return p.dir == ModulePort::Direction::Output;
});
}
size_t ModuleType::getNumPorts() { return getPorts().size(); }
SmallVector<Type> ModuleType::getInputTypes() {
SmallVector<Type> retval;
@ -788,14 +798,28 @@ ModuleType circt::hw::detail::fnToMod(Operation *op, ArrayAttr inputNames,
ModuleType circt::hw::detail::fnToMod(FunctionType fnty, ArrayAttr inputNames,
ArrayAttr outputNames) {
SmallVector<ModulePort> ports;
for (auto [t, n] : llvm::zip(fnty.getInputs(), inputNames))
if (auto iot = dyn_cast<hw::InOutType>(t))
ports.push_back({cast<StringAttr>(n), iot.getElementType(),
ModulePort::Direction::InOut});
else
ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
for (auto [t, n] : llvm::zip(fnty.getResults(), outputNames))
ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
if (inputNames) {
for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames))
if (auto iot = dyn_cast<hw::InOutType>(t))
ports.push_back({cast<StringAttr>(n), iot.getElementType(),
ModulePort::Direction::InOut});
else
ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
} else {
for (auto t : fnty.getInputs())
if (auto iot = dyn_cast<hw::InOutType>(t))
ports.push_back(
{{}, iot.getElementType(), ModulePort::Direction::InOut});
else
ports.push_back({{}, t, ModulePort::Direction::Input});
}
if (outputNames) {
for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames))
ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
} else {
for (auto t : fnty.getResults())
ports.push_back({{}, t, ModulePort::Direction::Output});
}
return ModuleType::get(fnty.getContext(), ports);
}

View File

@ -965,10 +965,6 @@ hw::ModulePortInfo MSFTModuleExternOp::getPortList() {
return hw::ModulePortInfo(inputs, outputs);
}
size_t MSFTModuleExternOp::getNumPorts() {
return getArgNames().size() + getResultNames().size();
}
hw::InnerSymAttr MSFTModuleExternOp::getPortSymbolAttr(size_t) { return {}; }
//===----------------------------------------------------------------------===//