[MSFT] Make `PassCommon` public (#3493)

`PassCommon` contains `getAndSortModules`, which is generally useful. Make it general by implementing `HWModuleLike` and `HWInstanceLike` and targeting those OpInterfaces instead. Involves a bit of ugliness when converting to/from the OpInterfaces to concrete module ops, but hopefully that'll go away as we add more functionality to said OpInterfaces.
This commit is contained in:
John Demme 2022-07-08 15:53:26 -07:00 committed by GitHub
parent ce85204ca9
commit aa5c574d2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 122 additions and 66 deletions

View File

@ -23,12 +23,6 @@
#include <functional>
namespace circt {
namespace msft {
void registerMSFTPasses();
} // namespace msft
} // namespace circt
#include "circt/Dialect/MSFT/MSFTDialect.h.inc"
#include "circt/Dialect/MSFT/MSFTEnums.h.inc"

View File

@ -6,11 +6,14 @@
//
//===----------------------------------------------------------------------===//
include "circt/Dialect/HW/HWOpInterfaces.td"
def InstanceOp : MSFTOp<"instance", [
Symbol,
ParentOneOf<["MSFTModuleOp"]>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<HWInstanceLike>
]> {
let summary = "Instantiate a module";
@ -68,7 +71,8 @@ def MSFTModuleOp : MSFTOp<"module",
[IsolatedFromAbove, FunctionOpInterface, Symbol, RegionKindInterface,
HasParent<"mlir::ModuleOp">,
SingleBlockImplicitTerminator<"OutputOp">,
OpAsmOpInterface]>{
OpAsmOpInterface,
DeclareOpInterfaceMethods<HWModuleLike>]>{
let summary = "MSFT HW Module";
let description = [{
A lot like `hw.module`, but with a few differences:

View File

@ -0,0 +1,45 @@
//===- MSFTPasses.h - Common code for passes --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_DIALECT_MSFT_MSFTPASSES_H
#define CIRCT_DIALECT_MSFT_MSFTPASSES_H
#include "circt/Dialect/MSFT/MSFTOps.h"
#include "circt/Dialect/HW/HWOpInterfaces.h"
#include "circt/Dialect/HW/HWOps.h"
namespace circt {
namespace msft {
void registerMSFTPasses();
/// A set of methods which are broadly useful in a number of dialects.
struct PassCommon {
protected:
SymbolCache topLevelSyms;
DenseMap<Operation *, SmallVector<hw::HWInstanceLike, 1>>
moduleInstantiations;
LogicalResult verifyInstances(ModuleOp topMod);
// Find all the modules and use the partial order of the instantiation DAG
// to sort them. If we use this order when "bubbling" up operations, we
// guarantee one-pass completeness. As a side-effect, populate the module to
// instantiation sites mapping.
//
// Assumption (unchecked): there is not a cycle in the instantiation graph.
void getAndSortModules(ModuleOp topMod,
SmallVectorImpl<hw::HWModuleLike> &mods);
void getAndSortModulesVisitor(hw::HWModuleLike mod,
SmallVectorImpl<hw::HWModuleLike> &mods,
DenseSet<Operation *> &modsSeen);
};
} // namespace msft
} // namespace circt
#endif // CIRCT_DIALECT_MSFT_MSFTPASSES_H

View File

@ -23,7 +23,7 @@
#include "circt/Dialect/HW/HWPasses.h"
#include "circt/Dialect/Handshake/HandshakePasses.h"
#include "circt/Dialect/LLHD/Transforms/Passes.h"
#include "circt/Dialect/MSFT/MSFTDialect.h"
#include "circt/Dialect/MSFT/MSFTPasses.h"
#include "circt/Dialect/SV/SVPasses.h"
#include "circt/Dialect/Seq/SeqPasses.h"
#include "circt/Transforms/Passes.h"

View File

@ -9,6 +9,7 @@
#include "circt/Dialect/MSFT/MSFTAttributes.h"
#include "circt/Dialect/MSFT/MSFTDialect.h"
#include "circt/Dialect/MSFT/MSFTOps.h"
#include "circt/Dialect/MSFT/MSFTPasses.h"
#include "circt/Support/LLVM.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"

View File

@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/MSFT/MSFTPasses.h"
#include "circt/Dialect/HW/HWAttributes.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/HW/HWTypes.h"
@ -528,29 +529,22 @@ std::unique_ptr<Pass> createExportTclPass() {
//===----------------------------------------------------------------------===//
namespace {
struct PassCommon {
struct MSFTPassCommon : PassCommon {
protected:
SymbolCache topLevelSyms;
DenseMap<MSFTModuleOp, SmallVector<InstanceOp, 1>> moduleInstantiations;
LogicalResult verifyInstances(ModuleOp topMod);
// Find all the modules and use the partial order of the instantiation DAG
// to sort them. If we use this order when "bubbling" up operations, we
// guarantee one-pass completeness. As a side-effect, populate the module to
// instantiation sites mapping.
//
// Assumption (unchecked): there is not a cycle in the instantiation graph.
void getAndSortModules(ModuleOp topMod, SmallVectorImpl<MSFTModuleOp> &mods);
void getAndSortModulesVisitor(MSFTModuleOp mod,
SmallVectorImpl<MSFTModuleOp> &mods,
DenseSet<MSFTModuleOp> &modsSeen);
SmallVector<InstanceOp, 1> &updateInstances(
/// Update all the instantiations of 'mod' to match the port list. For any
/// output ports which survived, automatically map the result according to
/// `newToOldResultMap`. Calls 'getOperandsFunc' with the new instance op, the
/// old instance op, and expects the operand vector to return filled.
/// `getOperandsFunc` can (and often does) modify other operations. The update
/// call deletes the original instance op, so all references are invalidated
/// after this call.
SmallVector<InstanceOp, 1> updateInstances(
MSFTModuleOp mod, ArrayRef<unsigned> newToOldResultMap,
llvm::function_ref<void(InstanceOp, InstanceOp, SmallVectorImpl<Value> &)>
getOperandsFunc);
void getAndSortModules(ModuleOp topMod, SmallVectorImpl<MSFTModuleOp> &mods);
void bubbleWiresUp(MSFTModuleOp mod);
void dedupOutputs(MSFTModuleOp mod);
void sinkWiresDown(MSFTModuleOp mod);
@ -566,21 +560,21 @@ static bool isWireManipulationOp(Operation *op) {
hw::ConstantOp>(op);
}
/// Update all the instantiations of 'mod' to match the port list. For any
/// output ports which survived, automatically map the result according to
/// `newToOldResultMap`. Calls 'getOperandsFunc' with the new instance op, the
/// old instance op, and expects the operand vector to return filled.
/// `getOperandsFunc` can (and often does) modify other operations. The update
/// call deletes the original instance op, so all references are invalidated
/// after this call.
SmallVector<InstanceOp, 1> &PassCommon::updateInstances(
SmallVector<InstanceOp, 1> MSFTPassCommon::updateInstances(
MSFTModuleOp mod, ArrayRef<unsigned> newToOldResultMap,
llvm::function_ref<void(InstanceOp, InstanceOp, SmallVectorImpl<Value> &)>
getOperandsFunc) {
SmallVector<InstanceOp, 1> newInstances;
for (InstanceOp inst : moduleInstantiations[mod]) {
assert(inst->getParentOp());
SmallVector<hw::HWInstanceLike, 1> newInstances;
SmallVector<InstanceOp, 1> newMsftInstances;
for (hw::HWInstanceLike instLike : moduleInstantiations[mod]) {
assert(instLike->getParentOp());
auto inst = dyn_cast<InstanceOp>(instLike.getOperation());
if (!inst) {
instLike.emitWarning("Can not update hw.instance ops");
continue;
}
OpBuilder b(inst);
auto newInst = b.create<InstanceOp>(inst.getLoc(), mod.getResultTypes(),
inst.getOperands(), inst->getAttrs());
@ -595,41 +589,55 @@ SmallVector<InstanceOp, 1> &PassCommon::updateInstances(
.replaceAllUsesWith(newInst.getResult(oldResult.index()));
newInstances.push_back(newInst);
newMsftInstances.push_back(newInst);
inst->dropAllUses();
inst->erase();
}
moduleInstantiations[mod].swap(newInstances);
return moduleInstantiations[mod];
return newMsftInstances;
}
// Run a post-order DFS.
void PassCommon::getAndSortModulesVisitor(MSFTModuleOp mod,
SmallVectorImpl<MSFTModuleOp> &mods,
DenseSet<MSFTModuleOp> &modsSeen) {
void PassCommon::getAndSortModulesVisitor(
hw::HWModuleLike mod, SmallVectorImpl<hw::HWModuleLike> &mods,
DenseSet<Operation *> &modsSeen) {
if (modsSeen.contains(mod))
return;
modsSeen.insert(mod);
mod.walk([&](InstanceOp inst) {
Operation *modOp = topLevelSyms.getDefinition(inst.moduleNameAttr());
auto mod = dyn_cast_or_null<MSFTModuleOp>(modOp);
if (!mod)
return;
moduleInstantiations[mod].push_back(inst);
getAndSortModulesVisitor(mod, mods, modsSeen);
mod.walk([&](hw::HWInstanceLike inst) {
Operation *modOp =
topLevelSyms.getDefinition(inst.referencedModuleNameAttr());
assert(modOp);
moduleInstantiations[modOp].push_back(inst);
if (auto modLike = dyn_cast<hw::HWModuleLike>(modOp))
getAndSortModulesVisitor(modLike, mods, modsSeen);
});
mods.push_back(mod);
}
void MSFTPassCommon::getAndSortModules(ModuleOp topMod,
SmallVectorImpl<MSFTModuleOp> &mods) {
SmallVector<hw::HWModuleLike, 16> moduleLikes;
PassCommon::getAndSortModules(topMod, moduleLikes);
mods.clear();
for (auto modLike : moduleLikes) {
auto mod = dyn_cast<MSFTModuleOp>(modLike.getOperation());
if (mod)
mods.push_back(mod);
}
}
void PassCommon::getAndSortModules(ModuleOp topMod,
SmallVectorImpl<MSFTModuleOp> &mods) {
SmallVectorImpl<hw::HWModuleLike> &mods) {
// Add here _before_ we go deeper to prevent infinite recursion.
DenseSet<MSFTModuleOp> modsSeen;
DenseSet<Operation *> modsSeen;
mods.clear();
moduleInstantiations.clear();
topMod.walk(
[&](MSFTModuleOp mod) { getAndSortModulesVisitor(mod, mods, modsSeen); });
topMod.walk([&](hw::HWModuleLike mod) {
getAndSortModulesVisitor(mod, mods, modsSeen);
});
}
LogicalResult PassCommon::verifyInstances(mlir::ModuleOp mod) {
@ -647,7 +655,7 @@ LogicalResult PassCommon::verifyInstances(mlir::ModuleOp mod) {
}
namespace {
struct PartitionPass : public PartitionBase<PartitionPass>, PassCommon {
struct PartitionPass : public PartitionBase<PartitionPass>, MSFTPassCommon {
void runOnOperation() override;
private:
@ -1427,7 +1435,8 @@ std::unique_ptr<Pass> createPartitionPass() {
} // namespace circt
namespace {
struct WireCleanupPass : public WireCleanupBase<WireCleanupPass>, PassCommon {
struct WireCleanupPass : public WireCleanupBase<WireCleanupPass>,
MSFTPassCommon {
void runOnOperation() override;
};
} // anonymous namespace
@ -1455,7 +1464,7 @@ void WireCleanupPass::runOnOperation() {
}
/// Remove outputs driven by the same value.
void PassCommon::dedupOutputs(MSFTModuleOp mod) {
void MSFTPassCommon::dedupOutputs(MSFTModuleOp mod) {
Block *body = mod.getBodyBlock();
Operation *terminator = body->getTerminator();
@ -1487,7 +1496,7 @@ void PassCommon::dedupOutputs(MSFTModuleOp mod) {
}
/// Push up any wires which are simply passed-through.
void PassCommon::bubbleWiresUp(MSFTModuleOp mod) {
void MSFTPassCommon::bubbleWiresUp(MSFTModuleOp mod) {
Block *body = mod.getBodyBlock();
Operation *terminator = body->getTerminator();
hw::ModulePortInfo ports = mod.getPorts();
@ -1552,13 +1561,15 @@ void PassCommon::bubbleWiresUp(MSFTModuleOp mod) {
updateInstances(mod, newToOldResult, setPassthroughsGetOperands);
}
void PassCommon::dedupInputs(MSFTModuleOp mod) {
auto instantiations = moduleInstantiations[mod];
void MSFTPassCommon::dedupInputs(MSFTModuleOp mod) {
const auto &instantiations = moduleInstantiations[mod];
// TODO: remove this limitation. This would involve looking at the common
// loopbacks for all the instances.
if (instantiations.size() != 1)
return;
InstanceOp inst = instantiations[0];
InstanceOp inst = dyn_cast<InstanceOp>(instantiations[0]);
if (!inst)
return;
// Find all the arguments which are driven by the same signal. Remap them
// appropriately within the module, and mark that input port for deletion.
@ -1589,8 +1600,7 @@ void PassCommon::dedupInputs(MSFTModuleOp mod) {
if (!argsToErase.test(argNum))
newOperands.push_back(oldInst.getOperand(argNum));
};
instantiations = updateInstances(mod, remappedResults, getOperands);
inst = instantiations[0];
inst = updateInstances(mod, remappedResults, getOperands)[0];
SmallVector<Attribute, 32> newArgNames;
std::string buff;
@ -1602,13 +1612,15 @@ void PassCommon::dedupInputs(MSFTModuleOp mod) {
}
/// Sink all the instance connections which are loops.
void PassCommon::sinkWiresDown(MSFTModuleOp mod) {
auto instantiations = moduleInstantiations[mod];
void MSFTPassCommon::sinkWiresDown(MSFTModuleOp mod) {
const auto &instantiations = moduleInstantiations[mod];
// TODO: remove this limitation. This would involve looking at the common
// loopbacks for all the instances.
if (instantiations.size() != 1)
return;
InstanceOp inst = instantiations[0];
InstanceOp inst = dyn_cast<InstanceOp>(instantiations[0]);
if (!inst)
return;
// Find all the "loopback" connections in the instantiation. Populate
// 'inputToOutputLoopback' with a mapping of input port to output port which
@ -1852,7 +1864,7 @@ std::unique_ptr<Pass> createLowerConstructsPass() {
namespace {
struct DiscoverAppIDsPass : public DiscoverAppIDsBase<DiscoverAppIDsPass>,
PassCommon {
MSFTPassCommon {
void runOnOperation() override;
void processMod(MSFTModuleOp);
};