mirror of https://github.com/llvm/circt.git
[NFC] Use ModuleType to compute many properties (#5715)
In anticipation of most modules converting from functions to portlists using module type, route most of the HWModuleLike interface through the ModuleType. Modules like things must now be able to produce a moduletype from which many accessors are derived. getNumPorts is separate from this as non-function-like modules need this to compute the module type.
This commit is contained in:
parent
24f3858255
commit
aa4942baaa
|
@ -82,8 +82,17 @@ class FIRRTLModuleLike<string mnemonic, list<Trait> traits = []> :
|
|||
}];
|
||||
|
||||
let extraClassDefinition = extraModuleClassDefinition # [{
|
||||
size_t $cppClass::getNumPorts() {
|
||||
return getPortTypesAttr().size();
|
||||
|
||||
circt::hw::ModuleType $cppClass::getHWModuleType() {
|
||||
SmallVector<hw::ModulePort> newPorts;
|
||||
auto num = getPortNames().size();
|
||||
newPorts.reserve(num);
|
||||
for (unsigned i = 0; i < num; ++i)
|
||||
newPorts.push_back({getPortNameAttr(i), getPortType(i),
|
||||
getPortDirection(i) == Direction::In ?
|
||||
hw::ModulePort::Direction::Input :
|
||||
hw::ModulePort::Direction::Output});
|
||||
return hw::ModuleType::get(getContext(), newPorts);
|
||||
}
|
||||
|
||||
circt::hw::InnerSymAttr $cppClass::getPortSymbolAttr(size_t portIndex) {
|
||||
|
|
|
@ -32,79 +32,60 @@ def HWModuleLike : OpInterface<"HWModuleLike", [
|
|||
let description = "Provide common module information.";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<"Get the number of ports",
|
||||
"size_t", "getNumPorts">,
|
||||
|
||||
InterfaceMethod<"Get a port symbol attribute",
|
||||
"::circt::hw::InnerSymAttr", "getPortSymbolAttr", (ins "size_t":$portIndex)>,
|
||||
|
||||
InterfaceMethod<"Get the module type",
|
||||
"::circt::hw::ModuleType", "getHWModuleType", (ins),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{ return ::circt::hw::detail::fnToMod($_op, getInputNames(), getOutputNames()); }]>,
|
||||
|
||||
InterfaceMethod<"Return the names of the inputs to this module",
|
||||
"mlir::ArrayAttr", "getInputNames", (ins),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
return $_op->template getAttrOfType<ArrayAttr>("argNames");
|
||||
}]>,
|
||||
|
||||
InterfaceMethod<"Return the names of the outputs this module",
|
||||
"mlir::ArrayAttr", "getOutputNames", (ins),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
return $_op->template getAttrOfType<ArrayAttr>("resultNames");
|
||||
}]>,
|
||||
|
||||
InterfaceMethod<"Return the locations of the inputs to this module",
|
||||
"mlir::ArrayAttr", "getInputLocs", (ins),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
return $_op->template getAttrOfType<ArrayAttr>("argLocs");
|
||||
}]>,
|
||||
|
||||
InterfaceMethod<"Return the locations of the outputs of this module",
|
||||
"mlir::ArrayAttr", "getOutputLocs", (ins),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
return $_op->template getAttrOfType<ArrayAttr>("resultLocs");
|
||||
}]>,
|
||||
|
||||
"::circt::hw::ModuleType", "getHWModuleType", (ins)>,
|
||||
];
|
||||
|
||||
let extraSharedClassDeclaration = [{
|
||||
|
||||
// Return the number of inputs to this module
|
||||
unsigned getNumInputs() {
|
||||
return getInputNames().size();
|
||||
/// Return the total number of ports in the module
|
||||
size_t getNumPorts() {
|
||||
return $_op.getHWModuleType().getNumPorts();
|
||||
}
|
||||
|
||||
// Return the number of outputs from this module
|
||||
unsigned getNumOutputs() {
|
||||
return getOutputNames().size();
|
||||
/// Return the total number of input and inout ports in the module
|
||||
size_t getNumInputs() {
|
||||
return $_op.getHWModuleType().getNumInputs();
|
||||
}
|
||||
|
||||
// Return the name of an input.
|
||||
mlir::StringRef getInputName(unsigned idx) {
|
||||
return getInputNameAttr(idx).getValue();
|
||||
/// Return the total number of output ports in the module
|
||||
size_t getNumOutputs() {
|
||||
return $_op.getHWModuleType().getNumOutputs();
|
||||
}
|
||||
|
||||
// Return the name of an output
|
||||
mlir::StringRef getOutputName(unsigned idx) {
|
||||
return getOutputNameAttr(idx).getValue();
|
||||
/// Return the set of names on input and inout ports
|
||||
SmallVector<StringAttr> getInputNames() {
|
||||
return $_op.getHWModuleType().getInputNames();
|
||||
}
|
||||
|
||||
// Return the name of an input to this module
|
||||
mlir::StringAttr getInputNameAttr(unsigned idx) {
|
||||
return cast<StringAttr>(getInputNames()[idx]);
|
||||
/// Return the set of names on output ports
|
||||
SmallVector<StringAttr> getOutputNames() {
|
||||
return $_op.getHWModuleType().getOutputNames();
|
||||
}
|
||||
|
||||
// Get the name for the specified input or inout port
|
||||
StringRef getInputName(size_t idx) {
|
||||
return $_op.getHWModuleType().getInputName(idx);
|
||||
}
|
||||
|
||||
// Return the name of an output to this module
|
||||
mlir::StringAttr getOutputNameAttr(unsigned idx) {
|
||||
return cast<StringAttr>(getOutputNames()[idx]);
|
||||
// Get the name for the specified output port
|
||||
StringRef getOutputName(size_t idx) {
|
||||
return $_op.getHWModuleType().getOutputName(idx);
|
||||
}
|
||||
|
||||
StringAttr getInputNameAttr(size_t idx) {
|
||||
return $_op.getHWModuleType().getInputNameAttr(idx);
|
||||
}
|
||||
|
||||
StringAttr getOutputNameAttr(size_t idx) {
|
||||
return $_op.getHWModuleType().getOutputNameAttr(idx);
|
||||
}
|
||||
|
||||
|
||||
}];
|
||||
|
||||
let verify = [{
|
||||
|
|
|
@ -67,9 +67,9 @@ class HWModuleOpBase<string mnemonic, list<Trait> traits = []> :
|
|||
code extraModuleClassDefinition = [{}];
|
||||
|
||||
let extraClassDefinition = extraModuleClassDefinition # [{
|
||||
size_t $cppClass::getNumPorts() {
|
||||
// This is wildly inefficient
|
||||
return getPortList().size();
|
||||
|
||||
ModuleType $cppClass::getHWModuleType() {
|
||||
return detail::fnToMod(getFunctionType(), getArgNames(), getResultNames());
|
||||
}
|
||||
|
||||
::circt::hw::InnerSymAttr $cppClass::getPortSymbolAttr(size_t portIndex) {
|
||||
|
|
|
@ -241,6 +241,7 @@ def ModuleTypeImpl : HWType<"Module"> {
|
|||
let extraClassDeclaration = [{
|
||||
// Many of these are transitional and will be removed when modules and instances
|
||||
// have moved over to this type.
|
||||
size_t getNumPorts();
|
||||
size_t getNumInputs();
|
||||
size_t getNumOutputs();
|
||||
SmallVector<Type> getInputTypes();
|
||||
|
|
|
@ -179,10 +179,9 @@ def MSFTModuleOp : MSFTModuleOpBase<"module",
|
|||
}];
|
||||
|
||||
let extraClassDefinition = [{
|
||||
size_t $cppClass::getNumPorts() {
|
||||
// This is excessively expensive
|
||||
auto ports = getPortList();
|
||||
return ports.size();
|
||||
|
||||
::circt::hw::ModuleType $cppClass::getHWModuleType() {
|
||||
return ::circt::hw::detail::fnToMod(getFunctionType(), getArgNames(), getResultNames());
|
||||
}
|
||||
|
||||
circt::hw::InnerSymAttr $cppClass::getPortSymbolAttr(size_t portIndex) {
|
||||
|
@ -226,6 +225,10 @@ def MSFTModuleExternOp : MSFTOp<"module.extern",
|
|||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
||||
/// Decode information about the input and output ports on this module.
|
||||
hw::ModulePortInfo getPorts();
|
||||
|
||||
/// Returns the argument types of this function.
|
||||
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
|
||||
|
||||
|
@ -237,6 +240,12 @@ def MSFTModuleExternOp : MSFTOp<"module.extern",
|
|||
//===------------------------------------------------------------------===//
|
||||
::circt::hw::ModulePortInfo getPortList();
|
||||
}];
|
||||
|
||||
let extraClassDefinition = [{
|
||||
::circt::hw::ModuleType $cppClass::getHWModuleType() {
|
||||
return ::circt::hw::detail::fnToMod(getFunctionType(), getArgNames(), getResultNames());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def DesignPartitionOp : MSFTOp<"partition",
|
||||
|
|
|
@ -530,16 +530,18 @@ LogicalResult ESIPureModuleOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
hw::ModulePortInfo ESIPureModuleOp::getPortList() {
|
||||
return hw::ModulePortInfo({});
|
||||
hw::ModuleType ESIPureModuleOp::getHWModuleType() {
|
||||
return hw::ModuleType::get(getContext(), {});
|
||||
}
|
||||
|
||||
size_t ESIPureModuleOp::getNumPorts() { return 0; }
|
||||
hw::InnerSymAttr ESIPureModuleOp::getPortSymbolAttr(size_t portIndex) {
|
||||
assert(false);
|
||||
::circt::hw::InnerSymAttr ESIPureModuleOp::getPortSymbolAttr(size_t) {
|
||||
return {};
|
||||
}
|
||||
|
||||
::circt::hw::ModulePortInfo ESIPureModuleOp::getPortList() {
|
||||
return hw::ModulePortInfo(ArrayRef<hw::PortInfo>{});
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "circt/Dialect/ESI/ESI.cpp.inc"
|
||||
|
||||
|
|
|
@ -615,13 +615,19 @@ ESIConnectServicesPass::surfaceReqs(hw::HWMutableModuleLike mod,
|
|||
// Create a replacement instance of the same operation type.
|
||||
SmallVector<NamedAttribute> newAttrs;
|
||||
for (auto attr : inst->getAttrs()) {
|
||||
if (attr.getName() == argsAttrName)
|
||||
newAttrs.push_back(b.getNamedAttr(argsAttrName, mod.getInputNames()));
|
||||
else if (attr.getName() == resultsAttrName)
|
||||
if (attr.getName() == argsAttrName) {
|
||||
auto names = mod.getInputNames();
|
||||
SmallVector<Attribute> nnames(names.begin(), names.end());
|
||||
newAttrs.push_back(
|
||||
b.getNamedAttr(resultsAttrName, mod.getOutputNames()));
|
||||
else
|
||||
b.getNamedAttr(argsAttrName, b.getArrayAttr(nnames)));
|
||||
} else if (attr.getName() == resultsAttrName) {
|
||||
auto names = mod.getOutputNames();
|
||||
SmallVector<Attribute> nnames(names.begin(), names.end());
|
||||
newAttrs.push_back(
|
||||
b.getNamedAttr(resultsAttrName, b.getArrayAttr(nnames)));
|
||||
} else {
|
||||
newAttrs.push_back(attr);
|
||||
}
|
||||
}
|
||||
auto *newInst = b.insert(Operation::create(
|
||||
inst->getLoc(), inst->getName(), newResultTypes, newOperands,
|
||||
|
|
|
@ -110,7 +110,7 @@ static bool applyToPort(AnnotationSet annos, Operation *op, size_t portCount,
|
|||
}
|
||||
|
||||
bool AnnotationSet::applyToPort(FModuleLike op, size_t portNo) const {
|
||||
return ::applyToPort(*this, op.getOperation(), getNumPorts(op), portNo);
|
||||
return ::applyToPort(*this, op.getOperation(), op.getNumPorts(), portNo);
|
||||
}
|
||||
|
||||
bool AnnotationSet::applyToPort(MemOp op, size_t portNo) const {
|
||||
|
|
|
@ -221,7 +221,7 @@ DeclKind firrtl::getDeclarationKind(Value val) {
|
|||
}
|
||||
|
||||
size_t firrtl::getNumPorts(Operation *op) {
|
||||
if (auto module = dyn_cast<hw::HWModuleLike>(op))
|
||||
if (auto module = dyn_cast<FModuleLike>(op))
|
||||
return module.getNumPorts();
|
||||
return op->getNumResults();
|
||||
}
|
||||
|
@ -458,7 +458,7 @@ Block *CircuitOp::getBodyBlock() { return &getBody().front(); }
|
|||
|
||||
static SmallVector<PortInfo> getPortImpl(FModuleLike module) {
|
||||
SmallVector<PortInfo> results;
|
||||
for (unsigned i = 0, e = getNumPorts(module); i < e; ++i) {
|
||||
for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i) {
|
||||
results.push_back({module.getPortNameAttr(i), module.getPortType(i),
|
||||
module.getPortDirection(i), module.getPortSymbolAttr(i),
|
||||
module.getPortLocation(i),
|
||||
|
@ -523,7 +523,7 @@ static void insertPorts(FModuleLike op,
|
|||
ArrayRef<std::pair<unsigned, PortInfo>> ports) {
|
||||
if (ports.empty())
|
||||
return;
|
||||
unsigned oldNumArgs = getNumPorts(op);
|
||||
unsigned oldNumArgs = op.getNumPorts();
|
||||
unsigned newNumArgs = oldNumArgs + ports.size();
|
||||
|
||||
// Add direction markers and names for new ports.
|
||||
|
@ -604,7 +604,7 @@ static void erasePorts(FModuleLike op, const llvm::BitVector &portIndices) {
|
|||
ArrayRef<Attribute> portAnnos = op.getPortAnnotations();
|
||||
ArrayRef<Attribute> portSyms = op.getPortSymbols();
|
||||
ArrayRef<Attribute> portLocs = op.getPortLocations();
|
||||
auto numPorts = getNumPorts(op);
|
||||
auto numPorts = op.getNumPorts();
|
||||
(void)numPorts;
|
||||
assert(portDirections.size() == numPorts);
|
||||
assert(portNames.size() == numPorts);
|
||||
|
@ -1710,7 +1710,7 @@ void InstanceOp::build(OpBuilder &builder, OperationState &result,
|
|||
|
||||
// Gather the result types.
|
||||
SmallVector<Type> resultTypes;
|
||||
resultTypes.reserve(getNumPorts(module));
|
||||
resultTypes.reserve(module.getNumPorts());
|
||||
llvm::transform(
|
||||
module.getPortTypes(), std::back_inserter(resultTypes),
|
||||
[](Attribute typeAttr) { return cast<TypeAttr>(typeAttr).getValue(); });
|
||||
|
@ -1860,7 +1860,7 @@ LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
// Check that all the attribute arrays are the right length up front. This
|
||||
// lets us safely use the port name in error messages below.
|
||||
size_t numResults = getNumResults();
|
||||
size_t numExpected = getNumPorts(referencedModule);
|
||||
size_t numExpected = referencedModule.getNumPorts();
|
||||
if (numResults != numExpected) {
|
||||
return emitNote(emitOpError() << "has a wrong number of results; expected "
|
||||
<< numExpected << " but got " << numResults);
|
||||
|
|
|
@ -867,7 +867,7 @@ void Inliner::mapPortsToWires(StringRef prefix, InliningLevel &il,
|
|||
const DenseSet<Attribute> &localSymbols) {
|
||||
auto target = il.childModule;
|
||||
auto portInfo = target.getPorts();
|
||||
for (unsigned i = 0, e = getNumPorts(target); i < e; ++i) {
|
||||
for (unsigned i = 0, e = target.getNumPorts(); i < e; ++i) {
|
||||
auto arg = target.getArgument(i);
|
||||
// Get the type of the wire.
|
||||
auto type = type_cast<FIRRTLType>(arg.getType());
|
||||
|
@ -1331,7 +1331,7 @@ void Inliner::identifyNLAsTargetingOnlyModules() {
|
|||
referencedNLASyms.insert(sym.getAttr());
|
||||
};
|
||||
// Scan ports
|
||||
for (unsigned i = 0, e = getNumPorts(mod); i != e; ++i)
|
||||
for (unsigned i = 0, e = mod.getNumPorts(); i != e; ++i)
|
||||
scanAnnos(AnnotationSet::forPort(mod, i));
|
||||
|
||||
// Scan operations (and not the module itself):
|
||||
|
|
|
@ -76,6 +76,8 @@ void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
|
|||
for (unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) {
|
||||
hw::PortInfo port;
|
||||
port.name = getArgName(i);
|
||||
if (!port.name)
|
||||
port.name = StringAttr::get(getContext(), "in" + std::to_string(i));
|
||||
port.dir = circt::hw::ModulePort::Direction::Input;
|
||||
port.type = machineType.getInput(i);
|
||||
port.argNum = i;
|
||||
|
@ -85,6 +87,8 @@ void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
|
|||
for (unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) {
|
||||
hw::PortInfo port;
|
||||
port.name = getResName(i);
|
||||
if (!port.name)
|
||||
port.name = StringAttr::get(getContext(), "out" + std::to_string(i));
|
||||
port.dir = circt::hw::ModulePort::Direction::Output;
|
||||
port.type = machineType.getResult(i);
|
||||
port.argNum = i;
|
||||
|
@ -94,9 +98,9 @@ void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
|
|||
|
||||
ParseResult MachineOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
auto buildFuncType =
|
||||
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
|
||||
function_interface_impl::VariadicFlag,
|
||||
std::string &) { return builder.getFunctionType(argTypes, results); };
|
||||
[&](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
|
||||
function_interface_impl::VariadicFlag,
|
||||
std::string &) { return builder.getFunctionType(argTypes, results); };
|
||||
|
||||
return function_interface_impl::parseFunctionOp(
|
||||
parser, result, /*allowVariadic=*/false,
|
||||
|
|
|
@ -603,9 +603,19 @@ LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
|
|||
return success();
|
||||
}
|
||||
|
||||
size_t ModuleType::getNumInputs() { return getInputTypes().size(); }
|
||||
size_t ModuleType::getNumInputs() {
|
||||
return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) {
|
||||
return p.dir != ModulePort::Direction::Output;
|
||||
});
|
||||
}
|
||||
|
||||
size_t ModuleType::getNumOutputs() { return getOutputTypes().size(); }
|
||||
size_t ModuleType::getNumOutputs() {
|
||||
return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) {
|
||||
return p.dir == ModulePort::Direction::Output;
|
||||
});
|
||||
}
|
||||
|
||||
size_t ModuleType::getNumPorts() { return getPorts().size(); }
|
||||
|
||||
SmallVector<Type> ModuleType::getInputTypes() {
|
||||
SmallVector<Type> retval;
|
||||
|
@ -788,14 +798,28 @@ ModuleType circt::hw::detail::fnToMod(Operation *op, ArrayAttr inputNames,
|
|||
ModuleType circt::hw::detail::fnToMod(FunctionType fnty, ArrayAttr inputNames,
|
||||
ArrayAttr outputNames) {
|
||||
SmallVector<ModulePort> ports;
|
||||
for (auto [t, n] : llvm::zip(fnty.getInputs(), inputNames))
|
||||
if (auto iot = dyn_cast<hw::InOutType>(t))
|
||||
ports.push_back({cast<StringAttr>(n), iot.getElementType(),
|
||||
ModulePort::Direction::InOut});
|
||||
else
|
||||
ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
|
||||
for (auto [t, n] : llvm::zip(fnty.getResults(), outputNames))
|
||||
ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
|
||||
if (inputNames) {
|
||||
for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames))
|
||||
if (auto iot = dyn_cast<hw::InOutType>(t))
|
||||
ports.push_back({cast<StringAttr>(n), iot.getElementType(),
|
||||
ModulePort::Direction::InOut});
|
||||
else
|
||||
ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
|
||||
} else {
|
||||
for (auto t : fnty.getInputs())
|
||||
if (auto iot = dyn_cast<hw::InOutType>(t))
|
||||
ports.push_back(
|
||||
{{}, iot.getElementType(), ModulePort::Direction::InOut});
|
||||
else
|
||||
ports.push_back({{}, t, ModulePort::Direction::Input});
|
||||
}
|
||||
if (outputNames) {
|
||||
for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames))
|
||||
ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
|
||||
} else {
|
||||
for (auto t : fnty.getResults())
|
||||
ports.push_back({{}, t, ModulePort::Direction::Output});
|
||||
}
|
||||
return ModuleType::get(fnty.getContext(), ports);
|
||||
}
|
||||
|
||||
|
|
|
@ -965,10 +965,6 @@ hw::ModulePortInfo MSFTModuleExternOp::getPortList() {
|
|||
return hw::ModulePortInfo(inputs, outputs);
|
||||
}
|
||||
|
||||
size_t MSFTModuleExternOp::getNumPorts() {
|
||||
return getArgNames().size() + getResultNames().size();
|
||||
}
|
||||
|
||||
hw::InnerSymAttr MSFTModuleExternOp::getPortSymbolAttr(size_t) { return {}; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue