mirror of https://github.com/llvm/circt.git
[circt-reduce] Try reductions on promising ops first (#3562)
This commit is contained in:
parent
d1b5d33616
commit
93bc76e091
|
@ -9,7 +9,7 @@ firrtl.circuit "Foo" {
|
|||
firrtl.module @FooFooBar(in %x: !firrtl.uint<1>, out %y: !firrtl.uint<1>) {
|
||||
firrtl.connect %y, %x : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
}
|
||||
// CHECK: firrtl.module @FooFoo
|
||||
// CHECK-NOT: firrtl.module @FooFoo
|
||||
firrtl.module @FooFoo(in %x: !firrtl.uint<1>, out %y: !firrtl.uint<1>) {
|
||||
%x0_x, %x0_y = firrtl.instance x0 @FooFooFoo(in x: !firrtl.uint<1>, out y: !firrtl.uint<1>)
|
||||
%x1_x, %x1_y = firrtl.instance x1 @FooFooBar(in x: !firrtl.uint<1>, out y: !firrtl.uint<1>)
|
||||
|
@ -21,7 +21,7 @@ firrtl.circuit "Foo" {
|
|||
firrtl.module @FooBar(in %x: !firrtl.uint<1>, out %y: !firrtl.uint<1>) {
|
||||
firrtl.connect %y, %x : !firrtl.uint<1>, !firrtl.uint<1>
|
||||
}
|
||||
// CHECK: firrtl.extmodule @Foo
|
||||
// CHECK: firrtl.module @Foo
|
||||
firrtl.module @Foo(in %x: !firrtl.uint<1>, out %y: !firrtl.uint<1>) {
|
||||
%x0_x, %x0_y = firrtl.instance x0 @FooFoo(in x: !firrtl.uint<1>, out y: !firrtl.uint<1>)
|
||||
%x1_x, %x1_y = firrtl.instance x1 @FooBar(in x: !firrtl.uint<1>, out y: !firrtl.uint<1>)
|
||||
|
|
|
@ -66,6 +66,93 @@ private:
|
|||
SmallDenseMap<Operation *, SymbolUserMap, 2> userMaps;
|
||||
};
|
||||
|
||||
/// Utility to easily get the instantiated firrtl::FModuleOp or an empty
|
||||
/// optional in case another type of module is instantiated.
|
||||
static llvm::Optional<firrtl::FModuleOp>
|
||||
findInstantiatedModule(firrtl::InstanceOp instOp, SymbolCache &symbols) {
|
||||
auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
|
||||
auto moduleOp = dyn_cast<firrtl::FModuleOp>(
|
||||
instOp.getReferencedModule(symbols.getSymbolTable(tableOp)));
|
||||
return moduleOp ? llvm::Optional(moduleOp)
|
||||
: llvm::Optional<firrtl::FModuleOp>();
|
||||
}
|
||||
|
||||
/// Compute the number of operations in a module. Recursively add the number of
|
||||
/// operations in instantiated modules.
|
||||
/// @param countMultipleInstantiations: If a module is instantiated multiple
|
||||
/// times and this flag is false, count it only once (to better represent code
|
||||
/// size reduction rather than area reduction of the actual hardware).
|
||||
/// @param countElsewhereInstantiated: If a module is also instantiated in
|
||||
/// another subtree of the design then don't count it if this flag is false.
|
||||
static uint64_t computeTransitiveModuleSize(
|
||||
SmallVector<std::pair<firrtl::FModuleOp, uint64_t>> &modules,
|
||||
SmallVector<Operation *> &instances, bool countMultipleInstantiations,
|
||||
bool countElsewhereInstantiated) {
|
||||
std::sort(instances.begin(), instances.end());
|
||||
std::sort(modules.begin(), modules.end(),
|
||||
[](auto a, auto b) { return a.first < b.first; });
|
||||
|
||||
auto *end = modules.end();
|
||||
if (!countMultipleInstantiations)
|
||||
end = std::unique(modules.begin(), modules.end(),
|
||||
[](auto a, auto b) { return a.first == b.first; });
|
||||
|
||||
uint64_t totalOperations = 0;
|
||||
|
||||
for (auto *iter = modules.begin(); iter != end; ++iter) {
|
||||
auto moduleOp = iter->first;
|
||||
|
||||
auto allInstancesCovered = [&]() {
|
||||
return llvm::all_of(
|
||||
moduleOp.getSymbolUses(moduleOp->getParentOfType<ModuleOp>()).value(),
|
||||
[&](auto symbolUse) {
|
||||
return std::binary_search(instances.begin(), instances.end(),
|
||||
symbolUse.getUser());
|
||||
});
|
||||
};
|
||||
|
||||
if (countElsewhereInstantiated || allInstancesCovered())
|
||||
totalOperations += iter->second;
|
||||
}
|
||||
|
||||
return totalOperations;
|
||||
}
|
||||
|
||||
static LogicalResult collectInstantiatedModules(
|
||||
llvm::Optional<firrtl::FModuleOp> fmoduleOp, SymbolCache &symbols,
|
||||
SmallVector<std::pair<firrtl::FModuleOp, uint64_t>> &modules,
|
||||
SmallVector<Operation *> &instances) {
|
||||
if (!fmoduleOp)
|
||||
return failure();
|
||||
|
||||
uint64_t opCount = 0;
|
||||
WalkResult result = fmoduleOp.value().walk([&](Operation *op) {
|
||||
if (auto instOp = dyn_cast<firrtl::InstanceOp>(op)) {
|
||||
auto moduleOp = findInstantiatedModule(instOp, symbols);
|
||||
if (!moduleOp) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "- `" << fmoduleOp.value().moduleName()
|
||||
<< "` recursively instantiated non-FIRRTL module.\n");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
if (failed(collectInstantiatedModules(moduleOp, symbols, modules,
|
||||
instances)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
instances.push_back(instOp);
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
if (result.wasInterrupted())
|
||||
return failure();
|
||||
|
||||
modules.push_back(std::make_pair(fmoduleOp.value(), opCount));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Check that all connections to a value are invalids.
|
||||
static bool onlyInvalidated(Value arg) {
|
||||
return llvm::all_of(arg.getUses(), [](OpOperand &use) {
|
||||
|
@ -167,7 +254,7 @@ PassReduction::PassReduction(MLIRContext *context, std::unique_ptr<Pass> pass,
|
|||
pm->addPass(std::move(pass));
|
||||
}
|
||||
|
||||
bool PassReduction::match(Operation *op) {
|
||||
uint64_t PassReduction::match(Operation *op) {
|
||||
return op->getName() == pm->getOpName(*context);
|
||||
}
|
||||
|
||||
|
@ -181,10 +268,26 @@ std::string PassReduction::getName() const { return passName.str(); }
|
|||
|
||||
/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
|
||||
struct ModuleExternalizer : public Reduction {
|
||||
void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
|
||||
void beforeReduction(mlir::ModuleOp op) override {
|
||||
nlaRemover.clear();
|
||||
symbols.clear();
|
||||
}
|
||||
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
|
||||
|
||||
bool match(Operation *op) override { return isa<firrtl::FModuleOp>(op); }
|
||||
uint64_t match(Operation *op) override {
|
||||
if (auto fmoduleOp = dyn_cast<firrtl::FModuleOp>(op)) {
|
||||
SmallVector<std::pair<firrtl::FModuleOp, uint64_t>> modules;
|
||||
SmallVector<Operation *> instances;
|
||||
if (failed(collectInstantiatedModules(fmoduleOp, symbols, modules,
|
||||
instances)))
|
||||
return 0;
|
||||
return computeTransitiveModuleSize(modules, instances,
|
||||
/*countMultipleInstantiations=*/false,
|
||||
/*countElsewhereInstantiated=*/true);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
LogicalResult rewrite(Operation *op) override {
|
||||
auto module = cast<firrtl::FModuleOp>(op);
|
||||
nlaRemover.markNLAsInOperation(op);
|
||||
|
@ -196,8 +299,10 @@ struct ModuleExternalizer : public Reduction {
|
|||
module->erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
std::string getName() const override { return "module-externalizer"; }
|
||||
|
||||
SymbolCache symbols;
|
||||
NLARemover nlaRemover;
|
||||
};
|
||||
|
||||
|
@ -351,7 +456,20 @@ struct InstanceStubber : public Reduction {
|
|||
nlaRemover.remove(op);
|
||||
}
|
||||
|
||||
bool match(Operation *op) override { return isa<firrtl::InstanceOp>(op); }
|
||||
uint64_t match(Operation *op) override {
|
||||
if (auto instOp = dyn_cast<firrtl::InstanceOp>(op)) {
|
||||
auto fmoduleOp = findInstantiatedModule(instOp, symbols);
|
||||
SmallVector<std::pair<firrtl::FModuleOp, uint64_t>> modules;
|
||||
SmallVector<Operation *> instances;
|
||||
if (failed(collectInstantiatedModules(fmoduleOp, symbols, modules,
|
||||
instances)))
|
||||
return 0;
|
||||
return computeTransitiveModuleSize(modules, instances,
|
||||
/*countMultipleInstantiations=*/false,
|
||||
/*countElsewhereInstantiated=*/false);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
LogicalResult rewrite(Operation *op) override {
|
||||
auto instOp = cast<firrtl::InstanceOp>(op);
|
||||
|
@ -398,7 +516,7 @@ struct InstanceStubber : public Reduction {
|
|||
struct MemoryStubber : public Reduction {
|
||||
void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
|
||||
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
|
||||
bool match(Operation *op) override { return isa<firrtl::MemOp>(op); }
|
||||
uint64_t match(Operation *op) override { return isa<firrtl::MemOp>(op); }
|
||||
LogicalResult rewrite(Operation *op) override {
|
||||
auto memOp = cast<firrtl::MemOp>(op);
|
||||
LLVM_DEBUG(llvm::dbgs() << "Stubbing memory `" << memOp.getName() << "`\n");
|
||||
|
@ -491,12 +609,12 @@ static bool isFlowSensitiveOp(Operation *op) {
|
|||
/// rapidly.
|
||||
template <unsigned OpNum>
|
||||
struct OperandForwarder : public Reduction {
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
if (op->getNumResults() != 1 || op->getNumOperands() < 2 ||
|
||||
OpNum >= op->getNumOperands())
|
||||
return false;
|
||||
return 0;
|
||||
if (isFlowSensitiveOp(op))
|
||||
return false;
|
||||
return 0;
|
||||
auto resultTy = op->getResult(0).getType().dyn_cast<firrtl::FIRRTLType>();
|
||||
auto opTy = op->getOperand(OpNum).getType().dyn_cast<firrtl::FIRRTLType>();
|
||||
return resultTy && opTy &&
|
||||
|
@ -535,11 +653,11 @@ struct OperandForwarder : public Reduction {
|
|||
/// A sample reduction pattern that replaces operations with a constant zero of
|
||||
/// their type.
|
||||
struct Constantifier : public Reduction {
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
if (op->getNumResults() != 1 || op->getNumOperands() == 0)
|
||||
return false;
|
||||
return 0;
|
||||
if (isFlowSensitiveOp(op))
|
||||
return false;
|
||||
return 0;
|
||||
auto type = op->getResult(0).getType().dyn_cast<firrtl::FIRRTLType>();
|
||||
return type && type.isa<firrtl::UIntType, firrtl::SIntType>();
|
||||
}
|
||||
|
@ -564,7 +682,7 @@ struct Constantifier : public Reduction {
|
|||
/// `firrtl.invalidvalue`. This removes uses from the fanin cone to these
|
||||
/// connects and creates opportunities for reduction in DCE/CSE.
|
||||
struct ConnectInvalidator : public Reduction {
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
return isa<firrtl::ConnectOp, firrtl::StrictConnectOp>(op) &&
|
||||
op->getOperand(1).getType().cast<firrtl::FIRRTLType>().isPassive() &&
|
||||
!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>();
|
||||
|
@ -589,7 +707,7 @@ struct ConnectInvalidator : public Reduction {
|
|||
/// results or their results have no users.
|
||||
struct OperationPruner : public Reduction {
|
||||
void beforeReduction(mlir::ModuleOp op) override { symbols.clear(); }
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
return !isa<ModuleOp>(op) &&
|
||||
(op->getNumResults() == 0 || op->use_empty()) &&
|
||||
(!op->hasAttr(SymbolTable::getSymbolAttrName()) ||
|
||||
|
@ -610,7 +728,7 @@ struct OperationPruner : public Reduction {
|
|||
struct AnnotationRemover : public Reduction {
|
||||
void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
|
||||
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
return op->hasAttr("annotations") || op->hasAttr("portAnnotations");
|
||||
}
|
||||
LogicalResult rewrite(Operation *op) override {
|
||||
|
@ -637,13 +755,13 @@ struct AnnotationRemover : public Reduction {
|
|||
/// A sample reduction pattern that removes ports from the root `firrtl.module`
|
||||
/// if the port is not used or just invalidated.
|
||||
struct RootPortPruner : public Reduction {
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
auto module = dyn_cast<firrtl::FModuleOp>(op);
|
||||
if (!module)
|
||||
return false;
|
||||
return 0;
|
||||
auto circuit = module->getParentOfType<firrtl::CircuitOp>();
|
||||
if (!circuit)
|
||||
return false;
|
||||
return 0;
|
||||
return circuit.getNameAttr() == module.getNameAttr();
|
||||
}
|
||||
LogicalResult rewrite(Operation *op) override {
|
||||
|
@ -673,11 +791,11 @@ struct ExtmoduleInstanceRemover : public Reduction {
|
|||
}
|
||||
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
|
||||
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
|
||||
return isa<firrtl::FExtModuleOp>(
|
||||
instOp.getReferencedModule(symbols.getNearestSymbolTable(instOp)));
|
||||
return false;
|
||||
return 0;
|
||||
}
|
||||
LogicalResult rewrite(Operation *op) override {
|
||||
auto instOp = cast<firrtl::InstanceOp>(op);
|
||||
|
@ -717,20 +835,20 @@ struct ConnectForwarder : public Reduction {
|
|||
op->erase();
|
||||
}
|
||||
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
if (!isa<firrtl::FConnectLike>(op))
|
||||
return false;
|
||||
return 0;
|
||||
auto dest = op->getOperand(0);
|
||||
auto src = op->getOperand(1);
|
||||
auto *destOp = dest.getDefiningOp();
|
||||
auto *srcOp = src.getDefiningOp();
|
||||
if (dest == src)
|
||||
return false;
|
||||
return 0;
|
||||
|
||||
// Ensure that the destination is something we should be able to forward
|
||||
// through.
|
||||
if (!isa_and_nonnull<firrtl::WireOp>(destOp))
|
||||
return false;
|
||||
return 0;
|
||||
|
||||
// Ensure that the destination is connected to only once, and all uses of
|
||||
// the connection occur after the definition of the source.
|
||||
|
@ -739,14 +857,14 @@ struct ConnectForwarder : public Reduction {
|
|||
auto *op = use.getOwner();
|
||||
if (use.getOperandNumber() == 0 && isa<firrtl::FConnectLike>(op)) {
|
||||
if (++numConnects > 1)
|
||||
return false;
|
||||
return 0;
|
||||
continue;
|
||||
}
|
||||
if (srcOp && !srcOp->isBeforeInBlock(op))
|
||||
return false;
|
||||
return 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
return 1;
|
||||
}
|
||||
|
||||
LogicalResult rewrite(Operation *op) override {
|
||||
|
@ -766,20 +884,20 @@ struct ConnectForwarder : public Reduction {
|
|||
/// an operand of the source value of the connection.
|
||||
template <unsigned OpNum>
|
||||
struct ConnectSourceOperandForwarder : public Reduction {
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
if (!isa<firrtl::ConnectOp, firrtl::StrictConnectOp>(op))
|
||||
return false;
|
||||
return 0;
|
||||
auto dest = op->getOperand(0);
|
||||
auto *destOp = dest.getDefiningOp();
|
||||
|
||||
// Ensure that the destination is used only once.
|
||||
if (!destOp || !destOp->hasOneUse() ||
|
||||
!isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(destOp))
|
||||
return false;
|
||||
return 0;
|
||||
|
||||
auto *srcOp = op->getOperand(1).getDefiningOp();
|
||||
if (!srcOp || OpNum >= srcOp->getNumOperands())
|
||||
return false;
|
||||
return 0;
|
||||
|
||||
auto resultTy = dest.getType().dyn_cast<firrtl::FIRRTLType>();
|
||||
auto opTy =
|
||||
|
@ -841,7 +959,7 @@ struct DetachSubaccesses : public Reduction {
|
|||
for (auto *op : opsToErase)
|
||||
op->erase();
|
||||
}
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
// Only applies to wires and registers that are purely used in subaccess
|
||||
// operations.
|
||||
return isa<firrtl::WireOp, firrtl::RegOp, firrtl::RegResetOp>(op) &&
|
||||
|
@ -883,11 +1001,11 @@ struct DetachSubaccesses : public Reduction {
|
|||
/// normal canonicalizations.
|
||||
struct NodeSymbolRemover : public Reduction {
|
||||
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
if (auto nodeOp = dyn_cast<firrtl::NodeOp>(op))
|
||||
return nodeOp.getInnerSym() &&
|
||||
!nodeOp.getInnerSym()->getSymName().getValue().empty();
|
||||
return false;
|
||||
return 0;
|
||||
}
|
||||
|
||||
LogicalResult rewrite(Operation *op) override {
|
||||
|
@ -907,14 +1025,14 @@ struct EagerInliner : public Reduction {
|
|||
}
|
||||
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
|
||||
|
||||
bool match(Operation *op) override {
|
||||
uint64_t match(Operation *op) override {
|
||||
auto instOp = dyn_cast<firrtl::InstanceOp>(op);
|
||||
if (!instOp)
|
||||
return false;
|
||||
return 0;
|
||||
auto tableOp = SymbolTable::getNearestSymbolTable(instOp);
|
||||
auto moduleOp = instOp.getReferencedModule(symbols.getSymbolTable(tableOp));
|
||||
if (!isa<firrtl::FModuleOp>(moduleOp))
|
||||
return false;
|
||||
return 0;
|
||||
return symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -48,8 +48,11 @@ struct Reduction {
|
|||
/// reductions before the resulting module is tried for interestingness.
|
||||
virtual void afterReduction(mlir::ModuleOp) {}
|
||||
|
||||
/// Check if the reduction can apply to a specific operation.
|
||||
virtual bool match(mlir::Operation *op) = 0;
|
||||
/// Check if the reduction can apply to a specific operation. Returns a
|
||||
/// benefit measure where a higher number means that applying the pattern
|
||||
/// leads to a bigger reduction and zero means that the patten does not
|
||||
/// match and thus cannot be applied at all.
|
||||
virtual uint64_t match(mlir::Operation *op) = 0;
|
||||
|
||||
/// Apply the reduction to a specific operation. If the returned result
|
||||
/// indicates that the application failed, the resulting module is treated the
|
||||
|
@ -87,7 +90,7 @@ struct Reduction {
|
|||
struct PassReduction : public Reduction {
|
||||
PassReduction(mlir::MLIRContext *context, std::unique_ptr<mlir::Pass> pass,
|
||||
bool canIncreaseSize = false, bool oneShot = false);
|
||||
bool match(mlir::Operation *op) override;
|
||||
uint64_t match(mlir::Operation *op) override;
|
||||
mlir::LogicalResult rewrite(mlir::Operation *op) override;
|
||||
std::string getName() const override;
|
||||
bool acceptSizeIncrease() const override { return canIncreaseSize; }
|
||||
|
|
|
@ -226,14 +226,22 @@ static LogicalResult execute(MLIRContext &context) {
|
|||
size_t opIdx = 0;
|
||||
mlir::OwningOpRef<mlir::ModuleOp> newModule = module->clone();
|
||||
pattern.beforeReduction(*newModule);
|
||||
SmallVector<std::pair<Operation *, uint64_t>, 16> opBenefits;
|
||||
newModule->walk([&](Operation *op) {
|
||||
if (!pattern.match(op))
|
||||
return;
|
||||
auto i = opIdx++;
|
||||
if (i < rangeBase || i - rangeBase >= rangeLength)
|
||||
return;
|
||||
(void)pattern.rewrite(op);
|
||||
uint64_t benefit = pattern.match(op);
|
||||
if (benefit > 0) {
|
||||
opIdx++;
|
||||
opBenefits.push_back(std::make_pair(op, benefit));
|
||||
}
|
||||
});
|
||||
std::sort(opBenefits.begin(), opBenefits.end(),
|
||||
[](auto a, auto b) { return a.second > b.second; });
|
||||
for (size_t i = rangeBase;
|
||||
i < rangeBase + rangeLength && i < opBenefits.size(); i++) {
|
||||
auto *op = opBenefits[i].first;
|
||||
if (pattern.match(op))
|
||||
(void)pattern.rewrite(op);
|
||||
}
|
||||
pattern.afterReduction(*newModule);
|
||||
if (opIdx == 0) {
|
||||
VERBOSE({
|
||||
|
|
Loading…
Reference in New Issue