[circt-reduce] Try reductions on promising ops first (#3562)

This commit is contained in:
Martin Erhart 2022-07-19 18:57:53 +02:00 committed by GitHub
parent d1b5d33616
commit 93bc76e091
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 176 additions and 47 deletions

View File

@ -9,7 +9,7 @@ firrtl.circuit "Foo" {
firrtl.module @FooFooBar(in %x: !firrtl.uint<1>, out %y: !firrtl.uint<1>) {
firrtl.connect %y, %x : !firrtl.uint<1>, !firrtl.uint<1>
}
// CHECK: firrtl.module @FooFoo
// CHECK-NOT: firrtl.module @FooFoo
firrtl.module @FooFoo(in %x: !firrtl.uint<1>, out %y: !firrtl.uint<1>) {
%x0_x, %x0_y = firrtl.instance x0 @FooFooFoo(in x: !firrtl.uint<1>, out y: !firrtl.uint<1>)
%x1_x, %x1_y = firrtl.instance x1 @FooFooBar(in x: !firrtl.uint<1>, out y: !firrtl.uint<1>)
@ -21,7 +21,7 @@ firrtl.circuit "Foo" {
firrtl.module @FooBar(in %x: !firrtl.uint<1>, out %y: !firrtl.uint<1>) {
firrtl.connect %y, %x : !firrtl.uint<1>, !firrtl.uint<1>
}
// CHECK: firrtl.extmodule @Foo
// CHECK: firrtl.module @Foo
firrtl.module @Foo(in %x: !firrtl.uint<1>, out %y: !firrtl.uint<1>) {
%x0_x, %x0_y = firrtl.instance x0 @FooFoo(in x: !firrtl.uint<1>, out y: !firrtl.uint<1>)
%x1_x, %x1_y = firrtl.instance x1 @FooBar(in x: !firrtl.uint<1>, out y: !firrtl.uint<1>)

View File

@ -66,6 +66,93 @@ private:
SmallDenseMap<Operation *, SymbolUserMap, 2> userMaps;
};
/// Utility to easily get the instantiated firrtl::FModuleOp or an empty
/// optional in case another type of module is instantiated.
static llvm::Optional<firrtl::FModuleOp>
findInstantiatedModule(firrtl::InstanceOp instOp, SymbolCache &symbols) {
auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
auto moduleOp = dyn_cast<firrtl::FModuleOp>(
instOp.getReferencedModule(symbols.getSymbolTable(tableOp)));
return moduleOp ? llvm::Optional(moduleOp)
: llvm::Optional<firrtl::FModuleOp>();
}
/// Compute the number of operations in a module. Recursively add the number of
/// operations in instantiated modules.
/// @param countMultipleInstantiations: If a module is instantiated multiple
/// times and this flag is false, count it only once (to better represent code
/// size reduction rather than area reduction of the actual hardware).
/// @param countElsewhereInstantiated: If a module is also instantiated in
/// another subtree of the design then don't count it if this flag is false.
static uint64_t computeTransitiveModuleSize(
SmallVector<std::pair<firrtl::FModuleOp, uint64_t>> &modules,
SmallVector<Operation *> &instances, bool countMultipleInstantiations,
bool countElsewhereInstantiated) {
std::sort(instances.begin(), instances.end());
std::sort(modules.begin(), modules.end(),
[](auto a, auto b) { return a.first < b.first; });
auto *end = modules.end();
if (!countMultipleInstantiations)
end = std::unique(modules.begin(), modules.end(),
[](auto a, auto b) { return a.first == b.first; });
uint64_t totalOperations = 0;
for (auto *iter = modules.begin(); iter != end; ++iter) {
auto moduleOp = iter->first;
auto allInstancesCovered = [&]() {
return llvm::all_of(
moduleOp.getSymbolUses(moduleOp->getParentOfType<ModuleOp>()).value(),
[&](auto symbolUse) {
return std::binary_search(instances.begin(), instances.end(),
symbolUse.getUser());
});
};
if (countElsewhereInstantiated || allInstancesCovered())
totalOperations += iter->second;
}
return totalOperations;
}
static LogicalResult collectInstantiatedModules(
llvm::Optional<firrtl::FModuleOp> fmoduleOp, SymbolCache &symbols,
SmallVector<std::pair<firrtl::FModuleOp, uint64_t>> &modules,
SmallVector<Operation *> &instances) {
if (!fmoduleOp)
return failure();
uint64_t opCount = 0;
WalkResult result = fmoduleOp.value().walk([&](Operation *op) {
if (auto instOp = dyn_cast<firrtl::InstanceOp>(op)) {
auto moduleOp = findInstantiatedModule(instOp, symbols);
if (!moduleOp) {
LLVM_DEBUG(llvm::dbgs()
<< "- `" << fmoduleOp.value().moduleName()
<< "` recursively instantiated non-FIRRTL module.\n");
return WalkResult::interrupt();
}
if (failed(collectInstantiatedModules(moduleOp, symbols, modules,
instances)))
return WalkResult::interrupt();
instances.push_back(instOp);
}
return WalkResult::advance();
});
if (result.wasInterrupted())
return failure();
modules.push_back(std::make_pair(fmoduleOp.value(), opCount));
return success();
}
/// Check that all connections to a value are invalids.
static bool onlyInvalidated(Value arg) {
return llvm::all_of(arg.getUses(), [](OpOperand &use) {
@ -167,7 +254,7 @@ PassReduction::PassReduction(MLIRContext *context, std::unique_ptr<Pass> pass,
pm->addPass(std::move(pass));
}
bool PassReduction::match(Operation *op) {
uint64_t PassReduction::match(Operation *op) {
return op->getName() == pm->getOpName(*context);
}
@ -181,10 +268,26 @@ std::string PassReduction::getName() const { return passName.str(); }
/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
struct ModuleExternalizer : public Reduction {
void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
void beforeReduction(mlir::ModuleOp op) override {
nlaRemover.clear();
symbols.clear();
}
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
bool match(Operation *op) override { return isa<firrtl::FModuleOp>(op); }
uint64_t match(Operation *op) override {
if (auto fmoduleOp = dyn_cast<firrtl::FModuleOp>(op)) {
SmallVector<std::pair<firrtl::FModuleOp, uint64_t>> modules;
SmallVector<Operation *> instances;
if (failed(collectInstantiatedModules(fmoduleOp, symbols, modules,
instances)))
return 0;
return computeTransitiveModuleSize(modules, instances,
/*countMultipleInstantiations=*/false,
/*countElsewhereInstantiated=*/true);
}
return 0;
}
LogicalResult rewrite(Operation *op) override {
auto module = cast<firrtl::FModuleOp>(op);
nlaRemover.markNLAsInOperation(op);
@ -196,8 +299,10 @@ struct ModuleExternalizer : public Reduction {
module->erase();
return success();
}
std::string getName() const override { return "module-externalizer"; }
SymbolCache symbols;
NLARemover nlaRemover;
};
@ -351,7 +456,20 @@ struct InstanceStubber : public Reduction {
nlaRemover.remove(op);
}
bool match(Operation *op) override { return isa<firrtl::InstanceOp>(op); }
uint64_t match(Operation *op) override {
if (auto instOp = dyn_cast<firrtl::InstanceOp>(op)) {
auto fmoduleOp = findInstantiatedModule(instOp, symbols);
SmallVector<std::pair<firrtl::FModuleOp, uint64_t>> modules;
SmallVector<Operation *> instances;
if (failed(collectInstantiatedModules(fmoduleOp, symbols, modules,
instances)))
return 0;
return computeTransitiveModuleSize(modules, instances,
/*countMultipleInstantiations=*/false,
/*countElsewhereInstantiated=*/false);
}
return 0;
}
LogicalResult rewrite(Operation *op) override {
auto instOp = cast<firrtl::InstanceOp>(op);
@ -398,7 +516,7 @@ struct InstanceStubber : public Reduction {
struct MemoryStubber : public Reduction {
void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
bool match(Operation *op) override { return isa<firrtl::MemOp>(op); }
uint64_t match(Operation *op) override { return isa<firrtl::MemOp>(op); }
LogicalResult rewrite(Operation *op) override {
auto memOp = cast<firrtl::MemOp>(op);
LLVM_DEBUG(llvm::dbgs() << "Stubbing memory `" << memOp.getName() << "`\n");
@ -491,12 +609,12 @@ static bool isFlowSensitiveOp(Operation *op) {
/// rapidly.
template <unsigned OpNum>
struct OperandForwarder : public Reduction {
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
if (op->getNumResults() != 1 || op->getNumOperands() < 2 ||
OpNum >= op->getNumOperands())
return false;
return 0;
if (isFlowSensitiveOp(op))
return false;
return 0;
auto resultTy = op->getResult(0).getType().dyn_cast<firrtl::FIRRTLType>();
auto opTy = op->getOperand(OpNum).getType().dyn_cast<firrtl::FIRRTLType>();
return resultTy && opTy &&
@ -535,11 +653,11 @@ struct OperandForwarder : public Reduction {
/// A sample reduction pattern that replaces operations with a constant zero of
/// their type.
struct Constantifier : public Reduction {
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
if (op->getNumResults() != 1 || op->getNumOperands() == 0)
return false;
return 0;
if (isFlowSensitiveOp(op))
return false;
return 0;
auto type = op->getResult(0).getType().dyn_cast<firrtl::FIRRTLType>();
return type && type.isa<firrtl::UIntType, firrtl::SIntType>();
}
@ -564,7 +682,7 @@ struct Constantifier : public Reduction {
/// `firrtl.invalidvalue`. This removes uses from the fanin cone to these
/// connects and creates opportunities for reduction in DCE/CSE.
struct ConnectInvalidator : public Reduction {
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
return isa<firrtl::ConnectOp, firrtl::StrictConnectOp>(op) &&
op->getOperand(1).getType().cast<firrtl::FIRRTLType>().isPassive() &&
!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>();
@ -589,7 +707,7 @@ struct ConnectInvalidator : public Reduction {
/// results or their results have no users.
struct OperationPruner : public Reduction {
void beforeReduction(mlir::ModuleOp op) override { symbols.clear(); }
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
return !isa<ModuleOp>(op) &&
(op->getNumResults() == 0 || op->use_empty()) &&
(!op->hasAttr(SymbolTable::getSymbolAttrName()) ||
@ -610,7 +728,7 @@ struct OperationPruner : public Reduction {
struct AnnotationRemover : public Reduction {
void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
return op->hasAttr("annotations") || op->hasAttr("portAnnotations");
}
LogicalResult rewrite(Operation *op) override {
@ -637,13 +755,13 @@ struct AnnotationRemover : public Reduction {
/// A sample reduction pattern that removes ports from the root `firrtl.module`
/// if the port is not used or just invalidated.
struct RootPortPruner : public Reduction {
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
auto module = dyn_cast<firrtl::FModuleOp>(op);
if (!module)
return false;
return 0;
auto circuit = module->getParentOfType<firrtl::CircuitOp>();
if (!circuit)
return false;
return 0;
return circuit.getNameAttr() == module.getNameAttr();
}
LogicalResult rewrite(Operation *op) override {
@ -673,11 +791,11 @@ struct ExtmoduleInstanceRemover : public Reduction {
}
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
return isa<firrtl::FExtModuleOp>(
instOp.getReferencedModule(symbols.getNearestSymbolTable(instOp)));
return false;
return 0;
}
LogicalResult rewrite(Operation *op) override {
auto instOp = cast<firrtl::InstanceOp>(op);
@ -717,20 +835,20 @@ struct ConnectForwarder : public Reduction {
op->erase();
}
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
if (!isa<firrtl::FConnectLike>(op))
return false;
return 0;
auto dest = op->getOperand(0);
auto src = op->getOperand(1);
auto *destOp = dest.getDefiningOp();
auto *srcOp = src.getDefiningOp();
if (dest == src)
return false;
return 0;
// Ensure that the destination is something we should be able to forward
// through.
if (!isa_and_nonnull<firrtl::WireOp>(destOp))
return false;
return 0;
// Ensure that the destination is connected to only once, and all uses of
// the connection occur after the definition of the source.
@ -739,14 +857,14 @@ struct ConnectForwarder : public Reduction {
auto *op = use.getOwner();
if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
if (++numConnects > 1)
return false;
return 0;
continue;
}
if (srcOp && !srcOp->isBeforeInBlock(op))
return false;
return 0;
}
return true;
return 1;
}
LogicalResult rewrite(Operation *op) override {
@ -766,20 +884,20 @@ struct ConnectForwarder : public Reduction {
/// an operand of the source value of the connection.
template <unsigned OpNum>
struct ConnectSourceOperandForwarder : public Reduction {
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
if (!isa<firrtl::ConnectOp, firrtl::StrictConnectOp>(op))
return false;
return 0;
auto dest = op->getOperand(0);
auto *destOp = dest.getDefiningOp();
// Ensure that the destination is used only once.
if (!destOp || !destOp->hasOneUse() ||
!isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
return false;
return 0;
auto *srcOp = op->getOperand(1).getDefiningOp();
if (!srcOp || OpNum >= srcOp->getNumOperands())
return false;
return 0;
auto resultTy = dest.getType().dyn_cast<firrtl::FIRRTLType>();
auto opTy =
@ -841,7 +959,7 @@ struct DetachSubaccesses : public Reduction {
for (auto *op : opsToErase)
op->erase();
}
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
// Only applies to wires and registers that are purely used in subaccess
// operations.
return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
@ -883,11 +1001,11 @@ struct DetachSubaccesses : public Reduction {
/// normal canonicalizations.
struct NodeSymbolRemover : public Reduction {
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
if (auto nodeOp = dyn_cast<firrtl::NodeOp>(op))
return nodeOp.getInnerSym() &&
!nodeOp.getInnerSym()->getSymName().getValue().empty();
return false;
return 0;
}
LogicalResult rewrite(Operation *op) override {
@ -907,14 +1025,14 @@ struct EagerInliner : public Reduction {
}
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
bool match(Operation *op) override {
uint64_t match(Operation *op) override {
auto instOp = dyn_cast<firrtl::InstanceOp>(op);
if (!instOp)
return false;
return 0;
auto tableOp = SymbolTable::getNearestSymbolTable(instOp);
auto moduleOp = instOp.getReferencedModule(symbols.getSymbolTable(tableOp));
if (!isa<firrtl::FModuleOp>(moduleOp))
return false;
return 0;
return symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1;
}

View File

@ -48,8 +48,11 @@ struct Reduction {
/// reductions before the resulting module is tried for interestingness.
virtual void afterReduction(mlir::ModuleOp) {}
/// Check if the reduction can apply to a specific operation.
virtual bool match(mlir::Operation *op) = 0;
/// Check if the reduction can apply to a specific operation. Returns a
/// benefit measure where a higher number means that applying the pattern
/// leads to a bigger reduction and zero means that the patten does not
/// match and thus cannot be applied at all.
virtual uint64_t match(mlir::Operation *op) = 0;
/// Apply the reduction to a specific operation. If the returned result
/// indicates that the application failed, the resulting module is treated the
@ -87,7 +90,7 @@ struct Reduction {
struct PassReduction : public Reduction {
PassReduction(mlir::MLIRContext *context, std::unique_ptr<mlir::Pass> pass,
bool canIncreaseSize = false, bool oneShot = false);
bool match(mlir::Operation *op) override;
uint64_t match(mlir::Operation *op) override;
mlir::LogicalResult rewrite(mlir::Operation *op) override;
std::string getName() const override;
bool acceptSizeIncrease() const override { return canIncreaseSize; }

View File

@ -226,14 +226,22 @@ static LogicalResult execute(MLIRContext &context) {
size_t opIdx = 0;
mlir::OwningOpRef<mlir::ModuleOp> newModule = module->clone();
pattern.beforeReduction(*newModule);
SmallVector<std::pair<Operation *, uint64_t>, 16> opBenefits;
newModule->walk([&](Operation *op) {
if (!pattern.match(op))
return;
auto i = opIdx++;
if (i < rangeBase || i - rangeBase >= rangeLength)
return;
(void)pattern.rewrite(op);
uint64_t benefit = pattern.match(op);
if (benefit > 0) {
opIdx++;
opBenefits.push_back(std::make_pair(op, benefit));
}
});
std::sort(opBenefits.begin(), opBenefits.end(),
[](auto a, auto b) { return a.second > b.second; });
for (size_t i = rangeBase;
i < rangeBase + rangeLength && i < opBenefits.size(); i++) {
auto *op = opBenefits[i].first;
if (pattern.match(op))
(void)pattern.rewrite(op);
}
pattern.afterReduction(*newModule);
if (opIdx == 0) {
VERBOSE({