[FIRRTL] Use BitVector over ArrayRef/SmallVector for eraseArguments. (#3837)

* [FIRRTL] Use BitVector over ArrayRef/SmallVector for eraseArguments.

Upstream MLIR removed an overload of `eraseArguments()` in
27e8ee208cb2142514ee2e3ab342dafaf6374f9e, stating that the overload
isn't useful because we should probably be using BitVector in most
cases.

This mostly affected code in the FIRRTL dialect that used SmallVector
for holding the list of port indices to delete, which indeed can be
replaced with a BitVector.

One thing to be careful about is that while `push_back()` is still
defined on BitVector, it has different behavior than
`SmallVector<unsigned>::push_back()`, in that the former only allows you
to push back a boolean 0 or 1 to the end of the bit vector, while the
latter pushes back an integer index. `BitVector::set()` matches the
previous behavior.

Co-authored-by: Will Dietz <will.dietz@sifive.com>
This commit is contained in:
Richard Xia 2022-09-08 14:54:19 -07:00 committed by GitHub
parent 4d92ac1aa2
commit f486947e15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 64 additions and 49 deletions

View File

@ -101,7 +101,7 @@ def InstanceOp : ReferableDeclOp<"instance", [HasParent<"firrtl::FModuleOp, firr
/// Builds a new `InstanceOp` with the ports listed in `portIndices` erased,
/// and updates any users of the remaining ports to point at the new
/// instance.
InstanceOp erasePorts(OpBuilder &builder, ArrayRef<unsigned> portIndices);
InstanceOp erasePorts(OpBuilder &builder, const llvm::BitVector &portIndices);
/// Clone the instance op and add ports. This is usually used in
/// conjuction with adding ports to the referenced module. This will emit

View File

@ -98,9 +98,8 @@ def FModuleOp : FIRRTLOp<"module", [IsolatedFromAbove, Symbol, SingleBlock,
/// Inserts the given ports.
void insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports);
/// Erases the ports listed in `portIndices`. `portIndices` is expected to
/// be in order and unique.
void erasePorts(ArrayRef<unsigned> portIndices);
/// Erases the ports that have their corresponding bit set in `portIndices`.
void erasePorts(const llvm::BitVector &portIndices);
void getAsmBlockArgumentNames(mlir::Region &region,
mlir::OpAsmSetValueNameFn setNameFn);

View File

@ -24,6 +24,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
@ -41,18 +42,17 @@ using namespace chirrtl;
// Utilities
//===----------------------------------------------------------------------===//
/// Remove elements at the specified indices from the input array, returning the
/// elements not mentioned. The indices array is expected to be sorted and
/// unique.
/// Remove elements from the input array corresponding to set bits in
/// `indicesToDrop`, returning the elements not mentioned.
template <typename T>
static SmallVector<T>
removeElementsAtIndices(ArrayRef<T> input, ArrayRef<unsigned> indicesToDrop) {
#ifndef NDEBUG // Check sortedness.
removeElementsAtIndices(ArrayRef<T> input,
const llvm::BitVector &indicesToDrop) {
#ifndef NDEBUG
if (!input.empty()) {
for (size_t i = 1, e = indicesToDrop.size(); i != e; ++i)
assert(indicesToDrop[i - 1] < indicesToDrop[i] &&
"indicesToDrop isn't sorted and unique");
assert(indicesToDrop.back() < input.size() && "index out of range");
int lastIndex = indicesToDrop.find_last();
if (lastIndex >= 0)
assert((size_t)lastIndex < input.size() && "index out of range");
}
#endif
@ -64,9 +64,9 @@ removeElementsAtIndices(ArrayRef<T> input, ArrayRef<unsigned> indicesToDrop) {
// Copy over the live chunks.
size_t lastCopied = 0;
SmallVector<T> result;
result.reserve(input.size() - indicesToDrop.size());
result.reserve(input.size() - indicesToDrop.count());
for (unsigned indexToDrop : indicesToDrop) {
for (unsigned indexToDrop : indicesToDrop.set_bits()) {
// If we skipped over some valid elements, copy them over.
if (indexToDrop > lastCopied) {
result.append(input.begin() + lastCopied, input.begin() + indexToDrop);
@ -574,10 +574,9 @@ void FMemModuleOp::insertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports) {
(*this).setPortSymbols(newSyms);
}
/// Erases the ports listed in `portIndices`. `portIndices` is expected to
/// be in order and unique.
void FModuleOp::erasePorts(ArrayRef<unsigned> portIndices) {
if (portIndices.empty())
/// Erases the ports that have their corresponding bit set in `portIndices`.
void FModuleOp::erasePorts(const llvm::BitVector &portIndices) {
if (portIndices.none())
return;
// Drop the direction markers for dead ports.
@ -1219,8 +1218,11 @@ void InstanceOp::build(OpBuilder &builder, OperationState &result,
/// Builds a new `InstanceOp` with the ports listed in `portIndices` erased, and
/// updates any users of the remaining ports to point at the new instance.
InstanceOp InstanceOp::erasePorts(OpBuilder &builder,
ArrayRef<unsigned> portIndices) {
if (portIndices.empty())
const llvm::BitVector &portIndices) {
assert(portIndices.size() >= getNumResults() &&
"portIndices is not at least as large as getNumResults()");
if (portIndices.none())
return *this;
SmallVector<Type> newResultTypes = removeElementsAtIndices<Type>(
@ -1237,10 +1239,9 @@ InstanceOp InstanceOp::erasePorts(OpBuilder &builder,
newPortDirections, newPortNames, getAnnotations().getValue(),
newPortAnnotations, getLowerToBind(), getInnerSymAttr());
SmallDenseSet<unsigned> portSet(portIndices.begin(), portIndices.end());
for (unsigned oldIdx = 0, newIdx = 0, numOldPorts = getNumResults();
oldIdx != numOldPorts; ++oldIdx) {
if (portSet.contains(oldIdx)) {
if (portIndices.test(oldIdx)) {
assert(getResult(oldIdx).use_empty() && "removed instance port has uses");
continue;
}

View File

@ -11,6 +11,7 @@
#include "circt/Dialect/FIRRTL/Passes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TinyPtrVector.h"
#include "llvm/Support/Debug.h"
@ -370,8 +371,8 @@ void IMDeadCodeElimPass::rewriteModuleSignature(FModuleOp module) {
instanceGraph->lookup(module.moduleNameAttr());
LLVM_DEBUG(llvm::dbgs() << "Prune ports of module: " << module.getName()
<< "\n");
SmallVector<unsigned> deadPortIndexes;
unsigned numOldPorts = module.getNumPorts();
llvm::BitVector deadPortIndexes(numOldPorts);
ImplicitLocOpBuilder builder(module.getLoc(), module.getContext());
builder.setInsertionPointToStart(module.getBodyBlock());
@ -408,7 +409,7 @@ void IMDeadCodeElimPass::rewriteModuleSignature(FModuleOp module) {
liveSet.erase(argument);
liveSet.insert(wire);
argument.replaceAllUsesWith(wire);
deadPortIndexes.push_back(index);
deadPortIndexes.set(index);
continue;
}
@ -417,11 +418,11 @@ void IMDeadCodeElimPass::rewriteModuleSignature(FModuleOp module) {
WireOp wire = builder.create<WireOp>(argument.getType());
argument.replaceAllUsesWith(wire);
assert(isAssumedDead(wire) && "dummy wire must be dead");
deadPortIndexes.push_back(index);
deadPortIndexes.set(index);
}
// If there is nothing to remove, abort.
if (deadPortIndexes.empty())
if (deadPortIndexes.none())
return;
// Erase arguments of the old module from liveSet to prevent from creating
@ -446,7 +447,7 @@ void IMDeadCodeElimPass::rewriteModuleSignature(FModuleOp module) {
liveSet.erase(oldResult);
// Replace old instance results with dummy wires.
for (auto index : deadPortIndexes) {
for (auto index : deadPortIndexes.set_bits()) {
auto result = instance.getResult(index);
assert(isAssumedDead(result) &&
"instance results of dead ports must be dead");
@ -466,7 +467,7 @@ void IMDeadCodeElimPass::rewriteModuleSignature(FModuleOp module) {
instance.erase();
}
numRemovedPorts += deadPortIndexes.size();
numRemovedPorts += deadPortIndexes.count();
}
void IMDeadCodeElimPass::eraseEmptyModule(FModuleOp module) {

View File

@ -42,6 +42,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Parallel.h"
@ -919,7 +920,7 @@ bool TypeLoweringVisitor::visitDecl(FModuleOp module) {
lowerBlock(body);
// Lower the module block arguments.
SmallVector<unsigned> argsToRemove;
llvm::BitVector argsToRemove;
auto newArgs = module.getPorts();
for (size_t argIndex = 0, argsRemoved = 0; argIndex < newArgs.size();
++argIndex) {
@ -927,15 +928,17 @@ bool TypeLoweringVisitor::visitDecl(FModuleOp module) {
if (lowerArg(module, argIndex, argsRemoved, newArgs, lowerings)) {
auto arg = module.getArgument(argIndex);
processUsers(arg, lowerings);
argsToRemove.push_back(argIndex);
argsToRemove.push_back(true);
++argsRemoved;
}
} else
argsToRemove.push_back(false);
// lowerArg might have invalidated any reference to newArgs, be careful
}
// Remove block args that have been lowered.
body->eraseArguments(argsToRemove);
for (auto deadArg : llvm::reverse(argsToRemove))
for (auto deadArg = argsToRemove.find_last(); deadArg != -1;
deadArg = argsToRemove.find_prev(deadArg))
newArgs.erase(newArgs.begin() + deadArg);
SmallVector<NamedAttribute, 8> newModuleAttrs;

View File

@ -15,6 +15,7 @@
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Dialect/FIRRTL/Passes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/PostOrderIterator.h"
@ -135,9 +136,13 @@ class LowerXMRPass : public LowerXMRBase<LowerXMRPass> {
return signalPassFailure();
// Record all the RefType ports to be removed later.
for (size_t portNum = 0, e = module.getNumPorts(); portNum < e; ++portNum)
if (module.getPortType(portNum).isa<RefType>())
refPortsToRemoveMap[module].push_back(portNum);
size_t numPorts = module.getNumPorts();
for (size_t portNum = 0; portNum < numPorts; ++portNum)
if (module.getPortType(portNum).isa<RefType>()) {
if (refPortsToRemoveMap[module].size() < numPorts)
refPortsToRemoveMap[module].resize(numPorts);
refPortsToRemoveMap[module].set(portNum);
}
}
LLVM_DEBUG({
@ -201,14 +206,17 @@ class LowerXMRPass : public LowerXMRBase<LowerXMRPass> {
LogicalResult handleInstanceOp(InstanceOp inst) {
auto refMod = dyn_cast<FModuleOp>(inst.getReferencedModule());
bool multiplyInstantiated = !visitedModules.insert(refMod).second;
for (size_t portNum = 0, e = inst.getNumResults(); portNum < e; ++portNum) {
for (size_t portNum = 0, numPorts = inst.getNumResults();
portNum < numPorts; ++portNum) {
auto instanceResult = inst.getResult(portNum);
if (!instanceResult.getType().isa<RefType>())
continue;
if (!refMod)
return inst.emitOpError("cannot lower ext modules with RefType ports");
// Reference ports must be removed.
refPortsToRemoveMap[inst].push_back(portNum);
if (refPortsToRemoveMap[inst].size() < numPorts)
refPortsToRemoveMap[inst].resize(numPorts);
refPortsToRemoveMap[inst].set(portNum);
// Drop the dead-instance-ports.
if (instanceResult.use_empty())
continue;
@ -351,7 +359,7 @@ class LowerXMRPass : public LowerXMRBase<LowerXMRPass> {
llvm::EquivalenceClasses<Value, ValueComparator> dataFlowClasses;
// Instance and module ref ports that needs to be removed.
DenseMap<Operation *, SmallVector<unsigned>> refPortsToRemoveMap;
DenseMap<Operation *, llvm::BitVector> refPortsToRemoveMap;
/// RefResolve, RefSend, and Connects involving them that will be removed.
SmallVector<Operation *> opsToRemove;

View File

@ -13,6 +13,7 @@
#include "circt/Dialect/FIRRTL/Passes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Support/Debug.h"
@ -53,12 +54,12 @@ void RemoveUnusedPortsPass::removeUnusedModulePorts(
FModuleOp module, InstanceGraphNode *instanceGraphNode) {
LLVM_DEBUG(llvm::dbgs() << "Prune ports of module: " << module.getName()
<< "\n");
// This tracks port indexes that can be erased.
SmallVector<unsigned> removalPortIndexes;
// This tracks constant values of output ports. None indicates an invalid
// value.
SmallVector<llvm::Optional<APSInt>> outputPortConstants;
auto ports = module.getPorts();
// This tracks port indexes that can be erased.
llvm::BitVector removalPortIndexes(ports.size());
for (const auto &e : llvm::enumerate(ports)) {
unsigned index = e.index();
@ -126,16 +127,16 @@ void RemoveUnusedPortsPass::removeUnusedModulePorts(
}
}
removalPortIndexes.push_back(index);
removalPortIndexes.set(index);
}
// If there is nothing to remove, abort.
if (removalPortIndexes.empty())
if (removalPortIndexes.none())
return;
// Delete ports from the module.
module.erasePorts(removalPortIndexes);
LLVM_DEBUG(llvm::for_each(removalPortIndexes, [&](unsigned index) {
LLVM_DEBUG(llvm::for_each(removalPortIndexes.set_bits(), [&](unsigned index) {
llvm::dbgs() << "Delete port: " << ports[index].name << "\n";
}););
@ -144,7 +145,7 @@ void RemoveUnusedPortsPass::removeUnusedModulePorts(
auto instance = ::cast<InstanceOp>(*use->getInstance());
ImplicitLocOpBuilder builder(instance.getLoc(), instance);
unsigned outputPortIndex = 0;
for (auto index : removalPortIndexes) {
for (auto index : removalPortIndexes.set_bits()) {
auto result = instance.getResult(index);
assert(!ports[index].isInOut() && "don't expect inout ports");
@ -195,7 +196,7 @@ void RemoveUnusedPortsPass::removeUnusedModulePorts(
instance.erase();
}
numRemovedPorts += removalPortIndexes.size();
numRemovedPorts += removalPortIndexes.count();
}
std::unique_ptr<mlir::Pass>

View File

@ -23,6 +23,7 @@
#include "mlir/Support/FileUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
@ -772,10 +773,11 @@ struct RootPortPruner : public Reduction {
LogicalResult rewrite(Operation *op) override {
assert(match(op));
auto module = cast<firrtl::FModuleOp>(op);
SmallVector<unsigned> dropPorts;
for (unsigned i = 0, e = module.getNumPorts(); i != e; ++i) {
size_t numPorts = module.getNumPorts();
llvm::BitVector dropPorts(numPorts);
for (unsigned i = 0; i != numPorts; ++i) {
if (onlyInvalidated(module.getArgument(i))) {
dropPorts.push_back(i);
dropPorts.set(i);
for (auto *user :
llvm::make_early_inc_range(module.getArgument(i).getUsers()))
user->erase();