[circt-reduce] Infrastructure improvements, bug fixes, and additional tests (#5131)

This implements some improvements to the circt-reduce infrastructure, fixes a few bugs, and adds more tests. More precisely:
* Split the `Reduction.cpp` file into a file containing the Reduction pattern base class implementations, one file per dialect for dialect specific patterns, a file for generic patterns, a file containing utilities to be shared across the dialect and generic pattern files
* Move the files containing dialect specific patterns to the corresponding dialect folder and use a new dialect interface to register them in `circt-reduce`. This is done in a way that only the patterns of dialects which were loaded during parsing are registered.
* Add tests to check that the correct set of patterns are registered after parsing
* Also move the regression tests to the corresponding dialect folders, but keep the generic patterns and tool specific tests in the `circt-reduce` test folder
* Fix some existing reduce patterns: `hw-operand-forwarder` didn't check if the operand and result values were the same in which case circt-reduce would just crash, `hw-constantifier` only checked if one of the result values is an integer, but then assumed that all of them are integers in the rewrite function.
* In the `PassReduction` constructor, the nesting of the pass manager was not handled correctly for non-nested passes (it tried to apply it to `ModuleOp` operations nested inside the root `ModuleOp`, also add support for passes operating on `hw.module`
* Add `hw-module-externalizer` reduction pattern
* Add a simple pattern for the Arc dialect
* A few more minor improvements

As a case study I reverted the combinational loop fix of LowerState (#5068) and ran circt-reduce on largeBoomConfig on the HW/Comb level. Before it took 237 sec to get to final size of 2105 (as reported by circt-reduce). With this PR applied it takes 86 sec to get a size of 905. Additionally, I checked for rocket on the firrtl level. With this PR it takes 483 sec to get a size of 20827, while without the PR, I cancelled after 824 sec where it was at a size of 522332.
This commit is contained in:
Martin Erhart 2023-05-04 19:41:29 +02:00 committed by GitHub
parent 3146ede582
commit cf02fee4f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 1001 additions and 330 deletions

View File

@ -0,0 +1,29 @@
//===- ArcReductions.h - Arc reduction interface declaration ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_DIALECT_ARC_ARCREDUCTIONS_H
#define CIRCT_DIALECT_ARC_ARCREDUCTIONS_H
#include "circt/Reduce/Reduction.h"
namespace circt {
namespace arc {
/// A dialect interface to provide reduction patterns to a reducer tool.
struct ArcReducePatternDialectInterface : public ReducePatternDialectInterface {
using ReducePatternDialectInterface::ReducePatternDialectInterface;
void populateReducePatterns(circt::ReducePatternSet &patterns) const override;
};
/// Register the Arc Reduction pattern dialect interface to the given registry.
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry);
} // namespace arc
} // namespace circt
#endif // CIRCT_DIALECT_ARC_ARCREDUCTIONS_H

View File

@ -0,0 +1,31 @@
//===- FIRRTLReductions.h - FIRRTL reduction interf. decl. ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_DIALECT_FIRRTL_FIRRTLREDUCTIONS_H
#define CIRCT_DIALECT_FIRRTL_FIRRTLREDUCTIONS_H
#include "circt/Reduce/Reduction.h"
namespace circt {
namespace firrtl {
/// A dialect interface to provide reduction patterns to a reducer tool.
struct FIRRTLReducePatternDialectInterface
: public ReducePatternDialectInterface {
using ReducePatternDialectInterface::ReducePatternDialectInterface;
void populateReducePatterns(circt::ReducePatternSet &patterns) const override;
};
/// Register the FIRRTL Reduction pattern dialect interface to the given
/// registry.
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry);
} // namespace firrtl
} // namespace circt
#endif // CIRCT_DIALECT_FIRRTL_FIRRTLREDUCTIONS_H

View File

@ -0,0 +1,29 @@
//===- HWReductions.h - HW reduction interface declaration ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_DIALECT_HW_HWREDUCTIONS_H
#define CIRCT_DIALECT_HW_HWREDUCTIONS_H
#include "circt/Reduce/Reduction.h"
namespace circt {
namespace hw {
/// A dialect interface to provide reduction patterns to a reducer tool.
struct HWReducePatternDialectInterface : public ReducePatternDialectInterface {
using ReducePatternDialectInterface::ReducePatternDialectInterface;
void populateReducePatterns(circt::ReducePatternSet &patterns) const override;
};
/// Register the HW Reduction pattern dialect interface to the given registry.
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry);
} // namespace hw
} // namespace circt
#endif // CIRCT_DIALECT_HW_HWREDUCTIONS_H

View File

@ -0,0 +1,23 @@
//===- GenericReductions.h - Generic reduction patterns ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_REDUCE_GENERICREDUCTIONS_H
#define CIRCT_REDUCE_GENERICREDUCTIONS_H
#include "circt/Reduce/Reduction.h"
namespace circt {
/// Populate reduction patterns that are not specific to certain operations or
/// dialects
void populateGenericReducePatterns(MLIRContext *context,
ReducePatternSet &patterns);
} // namespace circt
#endif // CIRCT_REDUCE_GENERICREDUCTIONS_H

View File

@ -1,4 +1,4 @@
//===- Reduction.h - Reductions for circt-reduce --------------------------===//
//===- Reduction.h - Reduction datastructure decl. for circt-reduce -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,31 +6,17 @@
//
//===----------------------------------------------------------------------===//
//
// This file defines abstract reduction patterns for the 'circt-reduce' tool.
// This file defines datastructures to handle reduction patterns.
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_REDUCE_REDUCTION_H
#define CIRCT_REDUCE_REDUCTION_H
#include <memory>
#include <string>
#include "circt/Support/LLVM.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/StringRef.h"
namespace llvm {
template <typename T>
class function_ref;
} // namespace llvm
namespace mlir {
struct LogicalResult;
class MLIRContext;
class Operation;
class Pass;
class PassManager;
} // namespace mlir
#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/SmallVector.h"
namespace circt {
@ -52,12 +38,12 @@ struct Reduction {
/// benefit measure where a higher number means that applying the pattern
/// leads to a bigger reduction and zero means that the patten does not
/// match and thus cannot be applied at all.
virtual uint64_t match(mlir::Operation *op) = 0;
virtual uint64_t match(Operation *op) = 0;
/// Apply the reduction to a specific operation. If the returned result
/// indicates that the application failed, the resulting module is treated the
/// same as if the tester marked it as uninteresting.
virtual mlir::LogicalResult rewrite(mlir::Operation *op) = 0;
virtual LogicalResult rewrite(Operation *op) = 0;
/// Return a human-readable name for this reduction pattern.
virtual std::string getName() const = 0;
@ -86,37 +72,81 @@ struct Reduction {
virtual bool isOneShot() const { return false; }
/// An optional callback for reductions to communicate removal of operations.
std::function<void(mlir::Operation *)> notifyOpErasedCallback = nullptr;
std::function<void(Operation *)> notifyOpErasedCallback = nullptr;
void notifyOpErased(mlir::Operation *op) {
void notifyOpErased(Operation *op) {
if (notifyOpErasedCallback)
notifyOpErasedCallback(op);
}
};
template <typename OpTy>
struct OpReduction : public Reduction {
uint64_t match(Operation *op) override {
if (auto concreteOp = dyn_cast<OpTy>(op))
return match(concreteOp);
return 0;
}
LogicalResult rewrite(Operation *op) override {
return rewrite(cast<OpTy>(op));
}
virtual uint64_t match(OpTy op) { return 1; }
virtual LogicalResult rewrite(OpTy op) = 0;
};
/// A reduction pattern that applies an `mlir::Pass`.
struct PassReduction : public Reduction {
PassReduction(mlir::MLIRContext *context, std::unique_ptr<mlir::Pass> pass,
PassReduction(MLIRContext *context, std::unique_ptr<Pass> pass,
bool canIncreaseSize = false, bool oneShot = false);
uint64_t match(mlir::Operation *op) override;
mlir::LogicalResult rewrite(mlir::Operation *op) override;
uint64_t match(Operation *op) override;
LogicalResult rewrite(Operation *op) override;
std::string getName() const override;
bool acceptSizeIncrease() const override { return canIncreaseSize; }
bool isOneShot() const override { return oneShot; }
protected:
mlir::MLIRContext *const context;
MLIRContext *const context;
std::unique_ptr<mlir::PassManager> pm;
llvm::StringRef passName;
StringRef passName;
bool canIncreaseSize;
bool oneShot;
};
/// Calls the function `add` with each available reduction, in the order they
/// should be applied.
void createAllReductions(
mlir::MLIRContext *context,
llvm::function_ref<void(std::unique_ptr<Reduction>)> add);
class ReducePatternSet {
public:
template <typename R, unsigned Benefit, typename... Args>
void add(Args &&...args) {
reducePatternsWithBenefit.push_back(
{std::make_unique<R>(std::forward<Args>(args)...), Benefit});
}
void filter(const std::function<bool(const Reduction &)> &pred);
void sortByBenefit();
size_t size() const;
Reduction &operator[](size_t idx) const;
private:
SmallVector<std::pair<std::unique_ptr<Reduction>, unsigned>>
reducePatternsWithBenefit;
};
/// A dialect interface to provide reduction patterns to a reducer tool.
struct ReducePatternDialectInterface
: public mlir::DialectInterface::Base<ReducePatternDialectInterface> {
ReducePatternDialectInterface(Dialect *dialect) : Base(dialect) {}
virtual void populateReducePatterns(ReducePatternSet &patterns) const = 0;
};
struct ReducePatternInterfaceCollection
: public mlir::DialectInterfaceCollection<ReducePatternDialectInterface> {
using Base::Base;
// Collect the reduce patterns defined by each dialect.
void populateReducePatterns(ReducePatternSet &patterns) const;
};
} // namespace circt

View File

@ -0,0 +1,27 @@
//===- ReductionUtils.h - Reduction pattern utilities -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_REDUCE_REDUCTIONUTILS_H
#define CIRCT_REDUCE_REDUCTIONUTILS_H
#include "circt/Support/LLVM.h"
namespace circt {
// Forward declarations.
struct Reduction;
namespace reduce {
/// Starting at the given `op`, traverse through it and its operands and erase
/// operations that have no more uses.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction);
} // namespace reduce
} // namespace circt
#endif // CIRCT_REDUCE_REDUCTIONUTILS_H

View File

@ -13,9 +13,6 @@
#ifndef CIRCT_REDUCE_TESTER_H
#define CIRCT_REDUCE_TESTER_H
#include <memory>
#include <vector>
#include "circt/Support/LLVM.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/SmallString.h"

View File

@ -73,11 +73,14 @@ namespace llvm {
template <typename KeyT, typename ValueT, unsigned InlineBuckets,
typename KeyInfoT, typename BucketT>
class SmallDenseMap;
template <typename T, unsigned N, typename C>
class SmallSet;
} // namespace llvm
// Import things we want into our namespace.
namespace circt {
using llvm::SmallDenseMap;
using llvm::SmallSet;
} // namespace circt
// Forward declarations of classes to be imported in to the circt namespace.

View File

@ -7,6 +7,7 @@ add_subdirectory(Bindings)
add_subdirectory(CAPI)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(Reduce)
add_subdirectory(Scheduling)
add_subdirectory(Support)
add_subdirectory(Target)

View File

@ -0,0 +1,67 @@
//===- ArcReductions.cpp - Reduction patterns for the Arc Dialect -=-------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Arc/ArcReductions.h"
#include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/Arc/ArcPasses.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "arc-reductions"
using namespace circt;
using namespace arc;
//===----------------------------------------------------------------------===//
// Reduction patterns
//===----------------------------------------------------------------------===//
/// A sample reduction pattern that converts `arc.state` operations to the
/// simpler `arc.call` operation and removes clock, latency, name attributes,
/// enables, and resets in the process.
struct StateElimination : public OpReduction<StateOp> {
LogicalResult rewrite(StateOp stateOp) override {
OpBuilder builder(stateOp);
ValueRange results =
builder
.create<arc::CallOp>(stateOp.getLoc(), stateOp->getResultTypes(),
stateOp.getArcAttr(), stateOp.getInputs())
->getResults();
stateOp.replaceAllUsesWith(results);
stateOp.erase();
return success();
}
std::string getName() const override { return "arc-state-elimination"; }
};
//===----------------------------------------------------------------------===//
// Reduction Registration
//===----------------------------------------------------------------------===//
void ArcReducePatternDialectInterface::populateReducePatterns(
circt::ReducePatternSet &patterns) const {
// Gather a list of reduction patterns that we should try. Ideally these are
// assigned reasonable benefit indicators (higher benefit patterns are
// prioritized). For example, things that can knock out entire modules while
// being cheap should be tried first (and thus have higher benefit), before
// trying to tweak operands of individual arithmetic ops.
patterns.add<PassReduction, 4>(getContext(), arc::createStripSVPass(), true,
true);
patterns.add<PassReduction, 3>(getContext(), arc::createDedupPass());
patterns.add<StateElimination, 2>();
patterns.add<PassReduction, 1>(getContext(), arc::createSinkInputsPass());
}
void arc::registerReducePatternDialectInterface(
mlir::DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, ArcDialect *dialect) {
dialect->addInterfaces<ArcReducePatternDialectInterface>();
});
}

View File

@ -1,3 +1,7 @@
set(LLVM_OPTIONAL_SOURCES
ArcReductions.cpp
)
add_circt_dialect_library(CIRCTArc
ArcDialect.cpp
ArcFolds.cpp
@ -22,6 +26,17 @@ add_circt_dialect_library(CIRCTArc
MLIRSideEffectInterfaces
)
add_circt_library(CIRCTArcReductions
ArcReductions.cpp
LINK_LIBS PUBLIC
CIRCTReduceLib
CIRCTArc
CIRCTArcTransforms
CIRCTConvertToArcs
MLIRIR
)
add_dependencies(circt-headers
MLIRArcIncGen
)

View File

@ -1,3 +1,7 @@
set(LLVM_OPTIONAL_SOURCES
FIRRTLReductions.cpp
)
include_directories(.)
add_circt_dialect_library(CIRCTFIRRTL
CHIRRTLDialect.cpp
@ -35,6 +39,16 @@ add_circt_dialect_library(CIRCTFIRRTL
MLIRPass
)
add_circt_library(CIRCTFIRRTLReductions
FIRRTLReductions.cpp
LINK_LIBS PUBLIC
CIRCTReduceLib
CIRCTFIRRTL
CIRCTFIRRTLTransforms
MLIRIR
)
add_dependencies(circt-headers
MLIRFIRRTLIncGen
CIRCTFIRRTLEnumsIncGen

View File

@ -1,35 +1,22 @@
//===- Reduction.cpp - Reductions for circt-reduce ------------------------===//
//===- FIRRTLReductions.cpp - Reduction patterns for the FIRRTL dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines abstract reduction patterns for the 'circt-reduce' tool.
//
//===----------------------------------------------------------------------===//
#include "Reduction.h"
#include "circt/Dialect/FIRRTL/FIRRTLReductions.h"
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/Passes.h"
#include "circt/InitAllDialects.h"
#include "mlir/IR/AsmState.h"
#include "circt/Reduce/ReductionUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "circt-reduce"
#define DEBUG_TYPE "firrtl-reductions"
using namespace llvm;
using namespace mlir;
using namespace circt;
@ -179,48 +166,11 @@ struct NLARemover {
};
//===----------------------------------------------------------------------===//
// Reduction
//===----------------------------------------------------------------------===//
Reduction::~Reduction() = default;
//===----------------------------------------------------------------------===//
// Pass Reduction
//===----------------------------------------------------------------------===//
PassReduction::PassReduction(MLIRContext *context, std::unique_ptr<Pass> pass,
bool canIncreaseSize, bool oneShot)
: context(context), canIncreaseSize(canIncreaseSize), oneShot(oneShot) {
passName = pass->getArgument();
if (passName.empty())
passName = pass->getName();
pm = std::make_unique<PassManager>(context, "builtin.module",
mlir::OpPassManager::Nesting::Explicit);
auto opName = pass->getOpName();
if (opName && opName->equals("firrtl.circuit"))
pm->nest<firrtl::CircuitOp>().addPass(std::move(pass));
else if (opName && opName->equals("firrtl.module"))
pm->nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>().addPass(
std::move(pass));
else
pm->nest<mlir::ModuleOp>().addPass(std::move(pass));
}
uint64_t PassReduction::match(Operation *op) {
return op->getName() == pm->getOpName(*context);
}
LogicalResult PassReduction::rewrite(Operation *op) { return pm->run(op); }
std::string PassReduction::getName() const { return passName.str(); }
//===----------------------------------------------------------------------===//
// Concrete Sample Reductions (to later move into the dialects)
// Reduction patterns
//===----------------------------------------------------------------------===//
/// A sample reduction pattern that maps `firrtl.module` to `firrtl.extmodule`.
struct ModuleExternalizer : public Reduction {
struct FIRRTLModuleExternalizer : public OpReduction<firrtl::FModuleOp> {
void beforeReduction(mlir::ModuleOp op) override {
nlaRemover.clear();
symbols.clear();
@ -228,15 +178,12 @@ struct ModuleExternalizer : public Reduction {
}
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
uint64_t match(Operation *op) override {
if (isa<firrtl::FModuleOp>(op))
return moduleSizes.getModuleSize(op, symbols);
return 0;
uint64_t match(firrtl::FModuleOp module) override {
return moduleSizes.getModuleSize(module, symbols);
}
LogicalResult rewrite(Operation *op) override {
auto module = cast<firrtl::FModuleOp>(op);
nlaRemover.markNLAsInOperation(op);
LogicalResult rewrite(firrtl::FModuleOp module) override {
nlaRemover.markNLAsInOperation(module);
OpBuilder builder(module);
builder.create<firrtl::FExtModuleOp>(
module->getLoc(),
@ -247,7 +194,7 @@ struct ModuleExternalizer : public Reduction {
return success();
}
std::string getName() const override { return "module-externalizer"; }
std::string getName() const override { return "firrtl-module-externalizer"; }
SymbolCache symbols;
NLARemover nlaRemover;
@ -365,7 +312,7 @@ static void reduceXor(ImplicitLocOpBuilder &builder, Value &into, Value value) {
/// A sample reduction pattern that maps `firrtl.instance` to a set of
/// invalidated wires. This often shortcuts a long iterative process of connect
/// invalidation, module externalization, and wire stripping
struct InstanceStubber : public Reduction {
struct InstanceStubber : public OpReduction<firrtl::InstanceOp> {
void beforeReduction(mlir::ModuleOp op) override {
erasedInsts.clear();
erasedModules.clear();
@ -405,15 +352,13 @@ struct InstanceStubber : public Reduction {
nlaRemover.remove(op);
}
uint64_t match(Operation *op) override {
if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
if (auto fmoduleOp = findInstantiatedModule(instOp, symbols))
return moduleSizes.getModuleSize(*fmoduleOp, symbols);
uint64_t match(firrtl::InstanceOp instOp) override {
if (auto fmoduleOp = findInstantiatedModule(instOp, symbols))
return moduleSizes.getModuleSize(*fmoduleOp, symbols);
return 0;
}
LogicalResult rewrite(Operation *op) override {
auto instOp = cast<firrtl::InstanceOp>(op);
LogicalResult rewrite(firrtl::InstanceOp instOp) override {
LLVM_DEBUG(llvm::dbgs()
<< "Stubbing instance `" << instOp.getName() << "`\n");
ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
@ -458,12 +403,10 @@ struct InstanceStubber : public Reduction {
/// A sample reduction pattern that maps `firrtl.mem` to a set of invalidated
/// wires.
struct MemoryStubber : public Reduction {
struct MemoryStubber : public OpReduction<firrtl::MemOp> {
void beforeReduction(mlir::ModuleOp op) override { nlaRemover.clear(); }
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
uint64_t match(Operation *op) override { return isa<firrtl::MemOp>(op); }
LogicalResult rewrite(Operation *op) override {
auto memOp = cast<firrtl::MemOp>(op);
LogicalResult rewrite(firrtl::MemOp memOp) override {
LLVM_DEBUG(llvm::dbgs() << "Stubbing memory `" << memOp.getName() << "`\n");
ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
SmallDenseMap<Type, Value, 8> invalidCache;
@ -531,25 +474,6 @@ struct MemoryStubber : public Reduction {
NLARemover nlaRemover;
};
/// Starting at the given `op`, traverse through it and its operands and erase
/// operations that have no more uses.
static void pruneUnusedOps(Operation *initialOp, Reduction &reduction) {
SmallVector<Operation *> worklist;
SmallSet<Operation *, 4> handled;
worklist.push_back(initialOp);
while (!worklist.empty()) {
auto *op = worklist.pop_back_val();
if (!op->use_empty())
continue;
for (auto arg : op->getOperands())
if (auto *argOp = arg.getDefiningOp())
if (handled.insert(argOp).second)
worklist.push_back(argOp);
reduction.notifyOpErased(op);
op->erase();
}
}
/// Check whether an operation interacts with flows in any way, which can make
/// replacement and operand forwarding harder in some cases.
static bool isFlowSensitiveOp(Operation *op) {
@ -598,7 +522,7 @@ struct FIRRTLOperandForwarder : public Reduction {
newOp = operand;
LLVM_DEBUG(llvm::dbgs() << "Forwarding " << newOp << " in " << *op << "\n");
result.replaceAllUsesWith(newOp);
pruneUnusedOps(op, *this);
reduce::pruneUnusedOps(op, *this);
return success();
}
std::string getName() const override {
@ -606,35 +530,6 @@ struct FIRRTLOperandForwarder : public Reduction {
}
};
/// A sample reduction pattern that replaces all uses of an operation with one
/// of its operands. This can help pruning large parts of the expression tree
/// rapidly.
template <unsigned OpNum>
struct HWOperandForwarder : public Reduction {
uint64_t match(Operation *op) override {
if (op->getNumResults() != 1 || op->getNumOperands() < 2 ||
OpNum >= op->getNumOperands())
return 0;
auto resultTy = op->getResult(0).getType().dyn_cast<IntegerType>();
auto opTy = op->getOperand(OpNum).getType().dyn_cast<IntegerType>();
return resultTy && opTy && resultTy == opTy;
}
LogicalResult rewrite(Operation *op) override {
assert(match(op));
ImplicitLocOpBuilder builder(op->getLoc(), op);
auto result = op->getResult(0);
auto operand = op->getOperand(OpNum);
LLVM_DEBUG(llvm::dbgs()
<< "Forwarding " << operand << " in " << *op << "\n");
result.replaceAllUsesWith(operand);
pruneUnusedOps(op, *this);
return success();
}
std::string getName() const override {
return ("hw-operand" + Twine(OpNum) + "-forwarder").str();
}
};
/// A sample reduction pattern that replaces FIRRTL operations with a constant
/// zero of their type.
struct FIRRTLConstantifier : public Reduction {
@ -656,36 +551,12 @@ struct FIRRTLConstantifier : public Reduction {
auto newOp = builder.create<firrtl::ConstantOp>(
op->getLoc(), type, APSInt(width, type.isa<firrtl::UIntType>()));
op->replaceAllUsesWith(newOp);
pruneUnusedOps(op, *this);
reduce::pruneUnusedOps(op, *this);
return success();
}
std::string getName() const override { return "firrtl-constantifier"; }
};
/// A sample reduction pattern that replaces integer operations with a constant
/// zero of their type.
struct HWConstantifier : public Reduction {
uint64_t match(Operation *op) override {
if (op->getNumResults() == 0 || op->getNumOperands() == 0)
return 0;
return llvm::any_of(op->getResults(), [](Value result) {
return result.getType().isa<IntegerType>();
});
}
LogicalResult rewrite(Operation *op) override {
assert(match(op));
OpBuilder builder(op);
for (auto result : op->getResults()) {
auto type = result.getType().cast<IntegerType>();
auto newOp = builder.create<hw::ConstantOp>(op->getLoc(), type, 0);
result.replaceAllUsesWith(newOp);
}
pruneUnusedOps(op, *this);
return success();
}
std::string getName() const override { return "hw-constantifier"; }
};
/// A sample reduction pattern that replaces the right-hand-side of
/// `firrtl.connect` and `firrtl.strictconnect` operations with a
/// `firrtl.invalidvalue`. This removes uses from the fanin cone to these
@ -693,7 +564,7 @@ struct HWConstantifier : public Reduction {
struct ConnectInvalidator : public Reduction {
uint64_t match(Operation *op) override {
if (!isa<firrtl::ConnectOp, firrtl::StrictConnectOp>(op))
return false;
return 0;
auto type = op->getOperand(1).getType().dyn_cast<firrtl::FIRRTLBaseType>();
return type && type.isPassive() &&
!op->getOperand(1).getDefiningOp<firrtl::InvalidValueOp>();
@ -707,33 +578,13 @@ struct ConnectInvalidator : public Reduction {
auto *rhsOp = rhs.getDefiningOp();
op->setOperand(1, invOp);
if (rhsOp)
pruneUnusedOps(rhsOp, *this);
reduce::pruneUnusedOps(rhsOp, *this);
return success();
}
std::string getName() const override { return "connect-invalidator"; }
bool acceptSizeIncrease() const override { return true; }
};
/// A sample reduction pattern that removes operations which either produce no
/// results or their results have no users.
struct OperationPruner : public Reduction {
void beforeReduction(mlir::ModuleOp op) override { symbols.clear(); }
uint64_t match(Operation *op) override {
return !isa<ModuleOp>(op) &&
(op->getNumResults() == 0 || op->use_empty()) &&
(!op->hasAttr(SymbolTable::getSymbolAttrName()) ||
symbols.getNearestSymbolUserMap(op).useEmpty(op));
}
LogicalResult rewrite(Operation *op) override {
assert(match(op));
pruneUnusedOps(op, *this);
return success();
}
std::string getName() const override { return "operation-pruner"; }
SymbolCache symbols;
};
/// A sample reduction pattern that removes FIRRTL annotations from ports and
/// operations.
struct AnnotationRemover : public Reduction {
@ -765,19 +616,15 @@ struct AnnotationRemover : public Reduction {
/// A sample reduction pattern that removes ports from the root `firrtl.module`
/// if the port is not used or just invalidated.
struct RootPortPruner : public Reduction {
uint64_t match(Operation *op) override {
auto module = dyn_cast<firrtl::FModuleOp>(op);
if (!module)
return 0;
struct RootPortPruner : public OpReduction<firrtl::FModuleOp> {
uint64_t match(firrtl::FModuleOp module) override {
auto circuit = module->getParentOfType<firrtl::CircuitOp>();
if (!circuit)
return 0;
return circuit.getNameAttr() == module.getNameAttr();
}
LogicalResult rewrite(Operation *op) override {
assert(match(op));
auto module = cast<firrtl::FModuleOp>(op);
LogicalResult rewrite(firrtl::FModuleOp module) override {
assert(match(module));
size_t numPorts = module.getNumPorts();
llvm::BitVector dropPorts(numPorts);
for (unsigned i = 0; i != numPorts; ++i) {
@ -796,21 +643,18 @@ struct RootPortPruner : public Reduction {
/// A sample reduction pattern that replaces instances of `firrtl.extmodule`
/// with wires.
struct ExtmoduleInstanceRemover : public Reduction {
struct ExtmoduleInstanceRemover : public OpReduction<firrtl::InstanceOp> {
void beforeReduction(mlir::ModuleOp op) override {
symbols.clear();
nlaRemover.clear();
}
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
uint64_t match(Operation *op) override {
if (auto instOp = dyn_cast<firrtl::InstanceOp>(op))
return isa<firrtl::FExtModuleOp>(
instOp.getReferencedModule(symbols.getNearestSymbolTable(instOp)));
return 0;
uint64_t match(firrtl::InstanceOp instOp) override {
return isa<firrtl::FExtModuleOp>(
instOp.getReferencedModule(symbols.getNearestSymbolTable(instOp)));
}
LogicalResult rewrite(Operation *op) override {
auto instOp = cast<firrtl::InstanceOp>(op);
LogicalResult rewrite(firrtl::InstanceOp instOp) override {
auto portInfo =
instOp.getReferencedModule(symbols.getNearestSymbolTable(instOp))
.getPorts();
@ -959,7 +803,7 @@ struct ConnectSourceOperandForwarder : public Reduction {
// because destination has only one use.
op->erase();
destOp->erase();
pruneUnusedOps(srcOp, *this);
reduce::pruneUnusedOps(srcOp, *this);
return success();
}
@ -1019,17 +863,14 @@ struct DetachSubaccesses : public Reduction {
/// This reduction removes symbols on node ops. Name preservation creates a lot
/// of nodes ops with symbols to keep name information but it also prevents
/// normal canonicalizations.
struct NodeSymbolRemover : public Reduction {
struct NodeSymbolRemover : public OpReduction<firrtl::NodeOp> {
uint64_t match(Operation *op) override {
if (auto nodeOp = dyn_cast<firrtl::NodeOp>(op))
return nodeOp.getInnerSym() &&
!nodeOp.getInnerSym()->getSymName().getValue().empty();
return 0;
uint64_t match(firrtl::NodeOp nodeOp) override {
return nodeOp.getInnerSym() &&
!nodeOp.getInnerSym()->getSymName().getValue().empty();
}
LogicalResult rewrite(Operation *op) override {
auto nodeOp = cast<firrtl::NodeOp>(op);
LogicalResult rewrite(firrtl::NodeOp nodeOp) override {
nodeOp.removeInnerSymAttr();
return success();
}
@ -1038,17 +879,14 @@ struct NodeSymbolRemover : public Reduction {
};
/// A sample reduction pattern that eagerly inlines instances.
struct EagerInliner : public Reduction {
struct EagerInliner : public OpReduction<firrtl::InstanceOp> {
void beforeReduction(mlir::ModuleOp op) override {
symbols.clear();
nlaRemover.clear();
}
void afterReduction(mlir::ModuleOp op) override { nlaRemover.remove(op); }
uint64_t match(Operation *op) override {
auto instOp = dyn_cast<firrtl::InstanceOp>(op);
if (!instOp)
return 0;
uint64_t match(firrtl::InstanceOp instOp) override {
auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
auto moduleOp = instOp.getReferencedModule(symbols.getSymbolTable(tableOp));
if (!isa<firrtl::FModuleOp>(moduleOp.getOperation()))
@ -1056,8 +894,7 @@ struct EagerInliner : public Reduction {
return symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1;
}
LogicalResult rewrite(Operation *op) override {
auto instOp = cast<firrtl::InstanceOp>(op);
LogicalResult rewrite(firrtl::InstanceOp instOp) override {
LLVM_DEBUG(llvm::dbgs()
<< "Inlining instance `" << instOp.getName() << "`\n");
SmallVector<Value> argReplacements;
@ -1103,59 +940,52 @@ struct EagerInliner : public Reduction {
// Reduction Registration
//===----------------------------------------------------------------------===//
static std::unique_ptr<Pass> createSimpleCanonicalizerPass() {
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification = false;
return createCanonicalizerPass(config);
void firrtl::FIRRTLReducePatternDialectInterface::populateReducePatterns(
circt::ReducePatternSet &patterns) const {
// Gather a list of reduction patterns that we should try. Ideally these are
// assigned reasonable benefit indicators (higher benefit patterns are
// prioritized). For example, things that can knock out entire modules while
// being cheap should be tried first (and thus have higher benefit), before
// trying to tweak operands of individual arithmetic ops.
patterns.add<PassReduction, 29>(getContext(),
firrtl::createLowerCHIRRTLPass(), true, true);
patterns.add<PassReduction, 28>(getContext(), firrtl::createInferWidthsPass(),
true, true);
patterns.add<PassReduction, 27>(getContext(), firrtl::createInferResetsPass(),
true, true);
patterns.add<FIRRTLModuleExternalizer, 26>();
patterns.add<InstanceStubber, 25>();
patterns.add<MemoryStubber, 24>();
patterns.add<EagerInliner, 23>();
patterns.add<PassReduction, 22>(
getContext(), firrtl::createLowerFIRRTLTypesPass(), true, true);
patterns.add<PassReduction, 21>(getContext(), firrtl::createExpandWhensPass(),
true, true);
patterns.add<PassReduction, 20>(getContext(), firrtl::createInlinerPass());
patterns.add<PassReduction, 18>(getContext(),
firrtl::createIMConstPropPass());
patterns.add<PassReduction, 17>(
getContext(),
firrtl::createRemoveUnusedPortsPass(/*ignoreDontTouch=*/true));
patterns.add<NodeSymbolRemover, 15>();
patterns.add<ConnectForwarder, 14>();
patterns.add<ConnectInvalidator, 13>();
patterns.add<FIRRTLConstantifier, 12>();
patterns.add<FIRRTLOperandForwarder<0>, 11>();
patterns.add<FIRRTLOperandForwarder<1>, 10>();
patterns.add<FIRRTLOperandForwarder<2>, 9>();
patterns.add<DetachSubaccesses, 7>();
patterns.add<AnnotationRemover, 6>();
patterns.add<RootPortPruner, 5>();
patterns.add<ExtmoduleInstanceRemover, 4>();
patterns.add<ConnectSourceOperandForwarder<0>, 3>();
patterns.add<ConnectSourceOperandForwarder<1>, 2>();
patterns.add<ConnectSourceOperandForwarder<2>, 1>();
}
void circt::createAllReductions(
MLIRContext *context,
llvm::function_ref<void(std::unique_ptr<Reduction>)> add) {
// Gather a list of reduction patterns that we should try. Ideally these are
// sorted by decreasing reduction potential/benefit. For example, things that
// can knock out entire modules while being cheap should be tried first,
// before trying to tweak operands of individual arithmetic ops.
add(std::make_unique<PassReduction>(context, firrtl::createLowerCHIRRTLPass(),
true, true));
add(std::make_unique<PassReduction>(context, firrtl::createInferWidthsPass(),
true, true));
add(std::make_unique<PassReduction>(context, firrtl::createInferResetsPass(),
true, true));
add(std::make_unique<ModuleExternalizer>());
add(std::make_unique<InstanceStubber>());
add(std::make_unique<MemoryStubber>());
add(std::make_unique<EagerInliner>());
add(std::make_unique<PassReduction>(
context, firrtl::createLowerFIRRTLTypesPass(), true, true));
add(std::make_unique<PassReduction>(context, firrtl::createExpandWhensPass(),
true, true));
add(std::make_unique<PassReduction>(context, firrtl::createInlinerPass()));
add(std::make_unique<PassReduction>(context,
createSimpleCanonicalizerPass()));
add(std::make_unique<PassReduction>(context,
firrtl::createIMConstPropPass()));
add(std::make_unique<PassReduction>(
context, firrtl::createRemoveUnusedPortsPass(/*ignoreDontTouch=*/true)));
add(std::make_unique<PassReduction>(context, createCSEPass()));
add(std::make_unique<NodeSymbolRemover>());
add(std::make_unique<ConnectForwarder>());
add(std::make_unique<ConnectInvalidator>());
add(std::make_unique<FIRRTLConstantifier>());
add(std::make_unique<HWConstantifier>());
add(std::make_unique<FIRRTLOperandForwarder<0>>());
add(std::make_unique<FIRRTLOperandForwarder<1>>());
add(std::make_unique<FIRRTLOperandForwarder<2>>());
add(std::make_unique<HWOperandForwarder<0>>());
add(std::make_unique<HWOperandForwarder<1>>());
add(std::make_unique<HWOperandForwarder<2>>());
add(std::make_unique<OperationPruner>());
add(std::make_unique<DetachSubaccesses>());
add(std::make_unique<AnnotationRemover>());
add(std::make_unique<RootPortPruner>());
add(std::make_unique<ExtmoduleInstanceRemover>());
add(std::make_unique<ConnectSourceOperandForwarder<0>>());
add(std::make_unique<ConnectSourceOperandForwarder<1>>());
add(std::make_unique<ConnectSourceOperandForwarder<2>>());
void firrtl::registerReducePatternDialectInterface(
mlir::DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, FIRRTLDialect *dialect) {
dialect->addInterfaces<FIRRTLReducePatternDialectInterface>();
});
}

View File

@ -1,5 +1,8 @@
add_circt_dialect_library(
CIRCTHW
set(LLVM_OPTIONAL_SOURCES
HWReductions.cpp
)
add_circt_dialect_library(CIRCTHW
CustomDirectiveImpl.cpp
HWAttributes.cpp
HWDialect.cpp
@ -30,4 +33,13 @@ add_circt_dialect_library(
MLIRInferTypeOpInterface
)
add_circt_library(CIRCTHWReductions
HWReductions.cpp
LINK_LIBS PUBLIC
CIRCTReduceLib
CIRCTHW
MLIRIR
)
add_subdirectory(Transforms)

View File

@ -0,0 +1,156 @@
//===- HWReductions.cpp - Reduction patterns for the HW dialect -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/HW/HWReductions.h"
#include "circt/Dialect/HW/HWInstanceGraph.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Reduce/ReductionUtils.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "hw-reductions"
using namespace mlir;
using namespace circt;
using namespace hw;
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
/// Utility to track the transitive size of modules.
struct ModuleSizeCache {
void clear() { moduleSizes.clear(); }
uint64_t getModuleSize(HWModuleLike module,
InstanceGraphBase &instanceGraph) {
if (auto it = moduleSizes.find(module); it != moduleSizes.end())
return it->second;
uint64_t size = 1;
module->walk([&](Operation *op) {
size += 1;
if (auto instOp = dyn_cast<HWInstanceLike>(op))
if (auto instModule = instanceGraph.getReferencedModule(instOp))
size += getModuleSize(instModule, instanceGraph);
});
moduleSizes.insert({module, size});
return size;
}
private:
llvm::DenseMap<Operation *, uint64_t> moduleSizes;
};
//===----------------------------------------------------------------------===//
// Reduction patterns
//===----------------------------------------------------------------------===//
/// A sample reduction pattern that maps `hw.module` to `hw.module.extern`.
struct ModuleExternalizer : public OpReduction<HWModuleOp> {
void beforeReduction(mlir::ModuleOp op) override {
instanceGraph = std::make_unique<InstanceGraph>(op);
moduleSizes.clear();
}
uint64_t match(HWModuleOp op) override {
return moduleSizes.getModuleSize(op, *instanceGraph);
}
LogicalResult rewrite(HWModuleOp op) override {
OpBuilder builder(op);
builder.create<HWModuleExternOp>(op->getLoc(), op.getModuleNameAttr(),
op.getPorts(), StringRef(),
op.getParameters());
op->erase();
return success();
}
std::string getName() const override { return "hw-module-externalizer"; }
std::unique_ptr<InstanceGraph> instanceGraph;
ModuleSizeCache moduleSizes;
};
/// A sample reduction pattern that replaces all uses of an operation with one
/// of its operands. This can help pruning large parts of the expression tree
/// rapidly.
template <unsigned OpNum>
struct HWOperandForwarder : public Reduction {
uint64_t match(Operation *op) override {
if (op->getNumResults() != 1 || op->getNumOperands() < 2 ||
OpNum >= op->getNumOperands())
return 0;
auto resultTy = op->getResult(0).getType().dyn_cast<IntegerType>();
auto opTy = op->getOperand(OpNum).getType().dyn_cast<IntegerType>();
return resultTy && opTy && resultTy == opTy &&
op->getResult(0) != op->getOperand(OpNum);
}
LogicalResult rewrite(Operation *op) override {
assert(match(op));
ImplicitLocOpBuilder builder(op->getLoc(), op);
auto result = op->getResult(0);
auto operand = op->getOperand(OpNum);
LLVM_DEBUG(llvm::dbgs()
<< "Forwarding " << operand << " in " << *op << "\n");
result.replaceAllUsesWith(operand);
reduce::pruneUnusedOps(op, *this);
return success();
}
std::string getName() const override {
return ("hw-operand" + Twine(OpNum) + "-forwarder").str();
}
};
/// A sample reduction pattern that replaces integer operations with a constant
/// zero of their type.
struct HWConstantifier : public Reduction {
uint64_t match(Operation *op) override {
if (op->getNumResults() == 0 || op->getNumOperands() == 0)
return 0;
return llvm::all_of(op->getResults(), [](Value result) {
return result.getType().isa<IntegerType>();
});
}
LogicalResult rewrite(Operation *op) override {
assert(match(op));
OpBuilder builder(op);
for (auto result : op->getResults()) {
auto type = result.getType().cast<IntegerType>();
auto newOp = builder.create<hw::ConstantOp>(op->getLoc(), type, 0);
result.replaceAllUsesWith(newOp);
}
reduce::pruneUnusedOps(op, *this);
return success();
}
std::string getName() const override { return "hw-constantifier"; }
};
//===----------------------------------------------------------------------===//
// Reduction Registration
//===----------------------------------------------------------------------===//
void HWReducePatternDialectInterface::populateReducePatterns(
circt::ReducePatternSet &patterns) const {
// Gather a list of reduction patterns that we should try. Ideally these are
// assigned reasonable benefit indicators (higher benefit patterns are
// prioritized). For example, things that can knock out entire modules while
// being cheap should be tried first (and thus have higher benefit), before
// trying to tweak operands of individual arithmetic ops.
patterns.add<ModuleExternalizer, 6>();
patterns.add<HWConstantifier, 5>();
patterns.add<HWOperandForwarder<0>, 4>();
patterns.add<HWOperandForwarder<1>, 3>();
patterns.add<HWOperandForwarder<2>, 2>();
}
void hw::registerReducePatternDialectInterface(
mlir::DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, HWDialect *dialect) {
dialect->addInterfaces<HWReducePatternDialectInterface>();
});
}

12
lib/Reduce/CMakeLists.txt Normal file
View File

@ -0,0 +1,12 @@
add_circt_library(CIRCTReduceLib
GenericReductions.cpp
Reduction.cpp
ReductionUtils.cpp
Tester.cpp
LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
MLIRTransforms
MLIRReduceLib
)

View File

@ -0,0 +1,62 @@
//===- GenericReductions.cpp - Generic Reduction patterns -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Reduce/GenericReductions.h"
#include "circt/Reduce/ReductionUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
using namespace circt;
//===----------------------------------------------------------------------===//
// Reduction Patterns
//===----------------------------------------------------------------------===//
/// A sample reduction pattern that removes operations which either produce no
/// results or their results have no users.
struct OperationPruner : public Reduction {
uint64_t match(Operation *op) override {
if (!isa<ModuleOp>(op) && (op->getNumResults() == 0 || op->use_empty()) &&
!op->hasAttr(SymbolTable::getSymbolAttrName()))
return true;
auto *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
SymbolUserMap userMap(table, symbolTableOp);
return !isa<ModuleOp>(op) &&
(op->getNumResults() == 0 || op->use_empty()) &&
userMap.useEmpty(op);
}
LogicalResult rewrite(Operation *op) override {
assert(match(op));
reduce::pruneUnusedOps(op, *this);
return success();
}
std::string getName() const override { return "operation-pruner"; }
SymbolTableCollection table;
};
//===----------------------------------------------------------------------===//
// Reduction Registration
//===----------------------------------------------------------------------===//
static std::unique_ptr<Pass> createSimpleCanonicalizerPass() {
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification = false;
return createCanonicalizerPass(config);
}
void circt::populateGenericReducePatterns(MLIRContext *context,
ReducePatternSet &patterns) {
patterns.add<PassReduction, 3>(context, createCSEPass());
patterns.add<PassReduction, 2>(context, createSimpleCanonicalizerPass());
patterns.add<OperationPruner, 1>();
}

98
lib/Reduce/Reduction.cpp Normal file
View File

@ -0,0 +1,98 @@
//===- Reduction.cpp - Reductions for circt-reduce ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines datastructures to handle reduction patterns.
//
//===----------------------------------------------------------------------===//
#include "circt/Reduce/Reduction.h"
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "circt-reduce"
using namespace circt;
//===----------------------------------------------------------------------===//
// Reduction
//===----------------------------------------------------------------------===//
Reduction::~Reduction() = default;
//===----------------------------------------------------------------------===//
// Pass Reduction
//===----------------------------------------------------------------------===//
PassReduction::PassReduction(MLIRContext *context, std::unique_ptr<Pass> pass,
bool canIncreaseSize, bool oneShot)
: context(context), canIncreaseSize(canIncreaseSize), oneShot(oneShot) {
passName = pass->getArgument();
if (passName.empty())
passName = pass->getName();
pm = std::make_unique<mlir::PassManager>(
context, "builtin.module", mlir::OpPassManager::Nesting::Explicit);
auto opName = pass->getOpName();
if (opName && opName->equals("firrtl.circuit"))
pm->nest<firrtl::CircuitOp>().addPass(std::move(pass));
else if (opName && opName->equals("firrtl.module"))
pm->nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>().addPass(
std::move(pass));
else if (opName && opName->equals("hw.module"))
pm->nest<hw::HWModuleOp>().addPass(std::move(pass));
else
pm->addPass(std::move(pass));
}
uint64_t PassReduction::match(Operation *op) {
return op->getName() == pm->getOpName(*context);
}
LogicalResult PassReduction::rewrite(Operation *op) { return pm->run(op); }
std::string PassReduction::getName() const { return passName.str(); }
//===----------------------------------------------------------------------===//
// ReducePatternSet
//===----------------------------------------------------------------------===//
void ReducePatternSet::filter(
const std::function<bool(const Reduction &)> &pred) {
for (auto *iter = reducePatternsWithBenefit.begin();
iter != reducePatternsWithBenefit.end(); ++iter) {
if (!pred(*iter->first))
reducePatternsWithBenefit.erase(iter--);
}
}
void ReducePatternSet::sortByBenefit() {
llvm::stable_sort(reducePatternsWithBenefit,
[](const auto &pairA, const auto &pairB) {
return pairA.second > pairB.second;
});
}
size_t ReducePatternSet::size() const {
return reducePatternsWithBenefit.size();
}
Reduction &ReducePatternSet::operator[](size_t idx) const {
return *reducePatternsWithBenefit[idx].first;
}
//===----------------------------------------------------------------------===//
// ReducePatternInterfaceCollection
//===----------------------------------------------------------------------===//
void ReducePatternInterfaceCollection::populateReducePatterns(
ReducePatternSet &patterns) const {
for (const ReducePatternDialectInterface &interface : *this)
interface.populateReducePatterns(patterns);
}

View File

@ -0,0 +1,31 @@
//===- ReductionUtils.cpp - Reduction pattern utilities -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Reduce/ReductionUtils.h"
#include "circt/Reduce/Reduction.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallSet.h"
using namespace circt;
void reduce::pruneUnusedOps(Operation *initialOp, Reduction &reduction) {
SmallVector<Operation *> worklist;
SmallSet<Operation *, 4> handled;
worklist.push_back(initialOp);
while (!worklist.empty()) {
auto *op = worklist.pop_back_val();
if (!op->use_empty())
continue;
for (auto arg : op->getOperands())
if (auto *argOp = arg.getDefiningOp())
if (handled.insert(argOp).second)
worklist.push_back(argOp);
reduction.notifyOpErased(op);
op->erase();
}
}

View File

@ -10,7 +10,7 @@
//
//===----------------------------------------------------------------------===//
#include "Tester.h"
#include "circt/Reduce/Tester.h"
#include "mlir/IR/Verifier.h"
#include "llvm/Support/ToolOutputFile.h"
@ -123,8 +123,8 @@ void TestCase::ensureFileOnDisk() {
// Pick a temporary output file path.
int fd;
std::error_code ec =
llvm::sys::fs::createTemporaryFile("mlir-reduce", "mlir", fd, filepath);
std::error_code ec = llvm::sys::fs::createTemporaryFile(
"circt-reduce", "mlir", fd, filepath);
if (ec)
llvm::report_fatal_error(
Twine("Error making unique filename: ") + ec.message(), false);

View File

@ -0,0 +1,20 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// This test checks that only the reduction patterns of dialects that occur in
// the input file are registered
// RUN: circt-reduce %s --test /bin/cat --list | FileCheck %s
// CHECK: arc-strip-sv
// CHECK-NEXT: cse
// CHECK-NEXT: arc-dedup
// CHECK-NEXT: canonicalize
// CHECK-NEXT: arc-state-elimination
// CHECK-NEXT: operation-pruner
// CHECK-NEXT: arc-sink-inputs
// CHECK-EMPTY:
arc.define @DummyArc(%arg0: i32) -> i32 {
arc.output %arg0 : i32
}

View File

@ -0,0 +1,14 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "DummyArc(%arg0)" --keep-best=0 --include arc-state-elimination | FileCheck %s
// CHECK-LABEL: hw.module @Foo
hw.module @Foo(%clk: i1, %en: i1, %rst: i1, %arg0: i32) -> (out: i32) {
// CHECK-NEXT: [[V0:%.+]] = arc.call @DummyArc(%arg0) : (i32) -> i32
%0 = arc.state @DummyArc(%arg0) clock %clk enable %en reset %rst lat 1 {name="reg1"} : (i32) -> (i32)
// CHECK-NEXT: hw.output [[V0]]
hw.output %0 : i32
}
arc.define @DummyArc(%arg0: i32) -> i32 {
arc.output %arg0 : i32
}

View File

@ -1,6 +1,6 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test /usr/bin/env --test-arg grep --test-arg "%anotherWire = firrtl.wire" --keep-best=0 --include annotation-remover | FileCheck %s
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "%anotherWire = firrtl.wire" --keep-best=0 --include annotation-remover | FileCheck %s
firrtl.circuit "Foo" {
// CHECK: firrtl.module @Foo

View File

@ -1,6 +1,6 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test %S.sh --test-arg cat --test-arg "firrtl.module @Foo" --keep-best=0 --include connect-source-operand-0-forwarder --test-must-fail | FileCheck %s
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "firrtl.module @Foo" --keep-best=0 --include connect-source-operand-0-forwarder | FileCheck %s
firrtl.circuit "Foo" {
// CHECK-LABEL: firrtl.module @Foo
firrtl.module @Foo(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %val: !firrtl.uint<2>) {

View File

@ -1,6 +1,6 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test %S.sh --test-arg firtool --test-arg "error: sink \"x1.x\" not fully initialized" --keep-best=0 --include root-port-pruner --test-must-fail | FileCheck %s
// RUN: circt-reduce %s --test /bin/sh --test-arg -c --test-arg 'firtool "$0" 2>&1 | grep -q "error: sink \"x1.x\" not fully initialized"' --keep-best=0 --include root-port-pruner | FileCheck %s
// https://github.com/llvm/circt/issues/3555
firrtl.circuit "Foo" {

View File

@ -1,6 +1,6 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test %S.sh --test-arg cat --test-arg "firrtl.module @Basic" --keep-best=0 --include memory-stubber --test-must-fail | FileCheck %s
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "firrtl.module @Basic" --keep-best=0 --include memory-stubber | FileCheck %s
firrtl.circuit "Basic" {
// CHECK-LABEL: @Basic

View File

@ -1,6 +1,6 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test %S.sh --test-arg cat --test-arg "%anotherWire = firrtl.node" --keep-best=0 --include node-symbol-remover --test-must-fail | FileCheck %s
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "%anotherWire = firrtl.node" --keep-best=0 --include node-symbol-remover | FileCheck %s
firrtl.circuit "Foo" {
// CHECK: firrtl.module @Foo

View File

@ -0,0 +1,46 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// This test checks that only the reduction patterns of dialects that occur in
// the input file are registered
// RUN: circt-reduce %s --test /bin/cat --list | FileCheck %s
// CHECK: firrtl-lower-chirrtl
// CHECK-NEXT: firrtl-infer-widths
// CHECK-NEXT: firrtl-infer-resets
// CHECK-NEXT: firrtl-module-externalizer
// CHECK-NEXT: instance-stubber
// CHECK-NEXT: memory-stubber
// CHECK-NEXT: eager-inliner
// CHECK-NEXT: firrtl-lower-types
// CHECK-NEXT: firrtl-expand-whens
// CHECK-NEXT: firrtl-inliner
// CHECK-NEXT: firrtl-imconstprop
// CHECK-NEXT: firrtl-remove-unused-ports
// CHECK-NEXT: node-symbol-remover
// CHECK-NEXT: connect-forwarder
// CHECK-NEXT: connect-invalidator
// CHECK-NEXT: firrtl-constantifier
// CHECK-NEXT: firrtl-operand0-forwarder
// CHECK-NEXT: firrtl-operand1-forwarder
// CHECK-NEXT: firrtl-operand2-forwarder
// CHECK-NEXT: detach-subaccesses
// CHECK-NEXT: hw-module-externalizer
// CHECK-NEXT: annotation-remover
// CHECK-NEXT: hw-constantifier
// CHECK-NEXT: root-port-pruner
// CHECK-NEXT: hw-operand0-forwarder
// CHECK-NEXT: extmodule-instance-remover
// CHECK-NEXT: cse
// CHECK-NEXT: hw-operand1-forwarder
// CHECK-NEXT: connect-source-operand-0-forwarder
// CHECK-NEXT: canonicalize
// CHECK-NEXT: hw-operand2-forwarder
// CHECK-NEXT: connect-source-operand-1-forwarder
// CHECK-NEXT: operation-pruner
// CHECK-NEXT: connect-source-operand-2-forwarder
// CHECK-EMPTY:
firrtl.circuit "Foo" {
firrtl.module @Foo() {}
}

View File

@ -1,6 +1,6 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test %S.sh --test-arg cat --test-arg "firrtl.module private @Bar" --keep-best=0 --include firrtl-remove-unused-ports --test-must-fail | FileCheck %s
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "firrtl.module private @Bar" --keep-best=0 --include firrtl-remove-unused-ports | FileCheck %s
firrtl.circuit "Foo" {
// CHECK-LABEL: firrtl.module @Foo

View File

@ -1,6 +1,6 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test %S.sh --test-arg firtool --test-arg "error: sink \"x1.x\" not fully initialized" --keep-best=0 --test-must-fail | FileCheck %s
// RUN: circt-reduce %s --test /bin/sh --test-arg -c --test-arg 'firtool "$0" 2>&1 | grep -q "error: sink \"x1.x\" not fully initialized"' --keep-best=0 | FileCheck %s
firrtl.circuit "Foo" {
// CHECK-NOT: firrtl.module @FooFooFoo

View File

@ -0,0 +1,31 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "hw.module @Foo" --keep-best=0 --include hw-constantifier | FileCheck %s
// CHECK-LABEL: hw.module @Foo
hw.module @Foo(%arg0: i32, %arg1: i32) -> (out0: i32, out1: i32) {
// CHECK-NEXT: [[V0:%.+]] = hw.constant 0
// CHECK-NEXT: [[V1:%.+]] = hw.constant 0
%inst.out0, %inst.out1 = hw.instance "inst" @Bar (arg0: %arg0: i32, arg1: %arg1: i32) -> (out0: i32, out1: i32)
// CHECK-NEXT: hw.output [[V0]], [[V1]]
hw.output %inst.out0, %inst.out1 : i32, i32
}
// CHECK: hw.module @Bar
hw.module @Bar(%arg0: i32, %arg1: i32) -> (out0: i32, out1: i32) {
hw.output %arg0, %arg1 : i32, i32
}
// CHECK-LABEL: hw.module @FooFoo
hw.module @FooFoo(%arg0: i32, %arg1: !hw.array<2xi32>) -> (out0: i32, out1: !hw.array<2xi32>) {
// CHECK-NEXT: [[V0:%.+]], [[V1:%.+]] = hw.instance
%inst.out0, %inst.out1 = hw.instance "inst" @FooBar (arg0: %arg0: i32, arg1: %arg1: !hw.array<2xi32>) -> (out0: i32, out1: !hw.array<2xi32>)
// CHECK-NEXT: hw.output [[V0]], [[V1]]
hw.output %inst.out0, %inst.out1 : i32, !hw.array<2xi32>
}
// CHECK: hw.module @FooBar
hw.module @FooBar(%arg0: i32, %arg1: !hw.array<2xi32>) -> (out0: i32, out1: !hw.array<2xi32>) {
hw.output %arg0, %arg1 : i32, !hw.array<2xi32>
}

View File

@ -0,0 +1,18 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "hw.instance" --keep-best=0 --include hw-module-externalizer | FileCheck %s
// CHECK-LABEL: hw.module @Foo
hw.module @Foo(%arg0: i32) -> (out: i32) {
// CHECK-NEXT: hw.instance
%inst.out = hw.instance "inst" @Bar (arg0: %arg0: i32) -> (out: i32)
// CHECK-NEXT: hw.output
hw.output %inst.out : i32
// CHECK-NEXT: }
}
// CHECK-NEXT: hw.module.extern @Bar
// CHECK-NOT: hw.module @Bar
hw.module @Bar(%arg0: i32) -> (out: i32) {
hw.output %arg0 : i32
}

View File

@ -0,0 +1,14 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "hw.module @Foo" --keep-best=0 --include hw-operand0-forwarder | FileCheck %s
// CHECK-LABEL: hw.module @Foo
hw.module @Foo(%arg0: i32, %arg1: i32) -> (out: i32) {
// COM: operand 0 is forwarded here
%0 = comb.and %arg0, %arg1 : i32
// CHECK-NEXT: [[V0:%.+]] = comb.and [[V0]], %arg0 : i32
// COM: cannot forward operand 0 here as it forms a loop
%1 = comb.and %1, %0 : i32
// CHECK-NEXT: hw.output [[V0]]
hw.output %1 : i32
}

View File

@ -0,0 +1,20 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// This test checks that only the reduction patterns of dialects that occur in
// the input file are registered
// RUN: circt-reduce %s --test /bin/cat --list | FileCheck %s
// CHECK: hw-module-externalizer
// CHECK-NEXT: hw-constantifier
// CHECK-NEXT: hw-operand0-forwarder
// CHECK-NEXT: cse
// CHECK-NEXT: hw-operand1-forwarder
// CHECK-NEXT: canonicalize
// CHECK-NEXT: hw-operand2-forwarder
// CHECK-NEXT: operation-pruner
// CHECK-EMPTY:
hw.module @Foo(%in: i1) -> (out: i1) {
hw.output %in : i1
}

View File

@ -0,0 +1,10 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "hw.module @Foo" --keep-best=0 --include canonicalize | FileCheck %s
// CHECK-LABEL: hw.module @Foo
hw.module @Foo(%arg0: i32) -> (out: i32) {
%0 = comb.and %arg0, %arg0 : i32
// CHECK: hw.output %arg0 :
hw.output %0 : i32
}

View File

@ -0,0 +1,9 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "hw.module @Foo" --keep-best=0 --include cse | FileCheck %s
// CHECK-LABEL: hw.module @Foo
hw.module @Foo() {
// CHECK-NOT: hw.constant
%0 = hw.constant 0 : i32
}

View File

@ -0,0 +1,13 @@
// UNSUPPORTED: system-windows
// See https://github.com/llvm/circt/issues/4129
// RUN: circt-reduce %s --test /bin/grep --test-arg -q --test-arg "hw.module @Foo" --keep-best=0 --include operation-pruner | FileCheck %s
// CHECK-LABEL: hw.module @Foo
hw.module @Foo(%arg0: i32) -> (out: i32) {
hw.output %arg0 : i32
}
// CHECK-NOT: hw.module @Bar
hw.module @Bar(%arg0: i32) -> (out: i32) {
hw.output %arg0 : i32
}

View File

@ -1,2 +0,0 @@
#!/bin/sh
! "$1" "$3" 2>&1 | grep "$2" >/dev/null

View File

@ -9,6 +9,10 @@ set(LIBS
${dialect_libs}
${mlir_dialect_libs}
CIRCTArcReductions
CIRCTHWReductions
CIRCTFIRRTLReductions
CIRCTReduceLib
MLIRIR
MLIRParser
MLIRSupport
@ -18,8 +22,6 @@ set(LIBS
add_llvm_tool(circt-reduce
circt-reduce.cpp
Reduction.cpp
Tester.cpp
DEPENDS ${LIBS}
)
target_link_libraries(circt-reduce PRIVATE ${LIBS})

View File

@ -11,9 +11,13 @@
//
//===----------------------------------------------------------------------===//
#include "Reduction.h"
#include "Tester.h"
#include "circt/Dialect/Arc/ArcReductions.h"
#include "circt/Dialect/FIRRTL/FIRRTLReductions.h"
#include "circt/Dialect/HW/HWDialect.h"
#include "circt/Dialect/HW/HWReductions.h"
#include "circt/InitAllDialects.h"
#include "circt/Reduce/GenericReductions.h"
#include "circt/Reduce/Tester.h"
#include "circt/Support/Version.h"
#include "mlir/IR/AsmState.h"
#include "mlir/Parser/Parser.h"
@ -150,24 +154,6 @@ static LogicalResult execute(MLIRContext &context) {
llvm::DenseSet<StringRef> exclusionSet(excludeReductions.begin(),
excludeReductions.end());
// Gather a list of reduction patterns that we should try.
SmallVector<std::unique_ptr<Reduction>> patterns;
createAllReductions(&context, [&](auto reduction) {
auto name = reduction->getName();
if (!inclusionSet.empty() && !inclusionSet.count(name))
return;
if (exclusionSet.count(name))
return;
patterns.push_back(std::move(reduction));
});
// Print the list of patterns.
if (listReductions) {
for (auto &pattern : patterns)
llvm::outs() << pattern->getName() << "\n";
return success();
}
// Parse the input file.
VERBOSE(llvm::errs() << "Reading input\n");
mlir::OwningOpRef<mlir::ModuleOp> module =
@ -175,6 +161,26 @@ static LogicalResult execute(MLIRContext &context) {
if (!module)
return failure();
// Gather a list of reduction patterns that we should try.
ReducePatternSet patterns;
populateGenericReducePatterns(&context, patterns);
ReducePatternInterfaceCollection reducePatternCollection(&context);
reducePatternCollection.populateReducePatterns(patterns);
auto reductionFilter = [&](const Reduction &reduction) {
auto name = reduction.getName();
return (inclusionSet.empty() || inclusionSet.count(name)) &&
!exclusionSet.count(name);
};
patterns.filter(reductionFilter);
patterns.sortByBenefit();
// Print the list of patterns.
if (listReductions) {
for (unsigned i = 0; i < patterns.size(); ++i)
llvm::outs() << patterns[i].getName() << "\n";
return success();
}
// Evaluate the unreduced input.
VERBOSE({
llvm::errs() << "Testing input with `" << testerCommand << "`\n";
@ -206,7 +212,7 @@ static LogicalResult execute(MLIRContext &context) {
// ModuleExternalizer pattern;
BitVector appliedOneShotPatterns(patterns.size(), false);
for (unsigned patternIdx = 0; patternIdx < patterns.size();) {
Reduction &pattern = *patterns[patternIdx];
auto &pattern = patterns[patternIdx];
if (pattern.isOneShot() && appliedOneShotPatterns[patternIdx]) {
LLVM_DEBUG(llvm::dbgs()
<< "Skipping one-shot `" << pattern.getName() << "`\n");
@ -414,6 +420,9 @@ int main(int argc, char **argv) {
// Register all the dialects and create a context to work wtih.
mlir::DialectRegistry registry;
registerAllDialects(registry);
arc::registerReducePatternDialectInterface(registry);
firrtl::registerReducePatternDialectInterface(registry);
hw::registerReducePatternDialectInterface(registry);
mlir::MLIRContext context(registry);
// Do the actual processing and use `exit` to avoid the slow teardown of the