[Scheduling] Get problem instances via factory methods. (#2227)

Replaces the public single-argument constructors that initialize the containingOp in the Problem hierarchy with factory methods.
This commit is contained in:
Julian Oppermann 2021-11-30 09:53:49 +01:00 committed by GitHub
parent 4edc6f2260
commit 926bfbad2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 20 deletions

View File

@ -26,6 +26,17 @@
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SetVector.h"
#define DEFINE_FACTORY_METHOD(ProblemClass) \
protected: \
ProblemClass() {} \
\
public: \
static ProblemClass get(Operation *containingOp) { \
ProblemClass prob; \
prob.setContainingOp(containingOp); \
return prob; \
}
namespace circt {
namespace scheduling {
@ -73,9 +84,9 @@ namespace scheduling {
/// for each registered operation, and the precedence constraints as modeled by
/// the dependences are satisfied.
class Problem {
DEFINE_FACTORY_METHOD(Problem)
public:
/// Initialize a scheduling problem corresponding to \p containingOp.
explicit Problem(Operation *containingOp) : containingOp(containingOp) {}
virtual ~Problem() = default;
friend detail::DependenceIterator;
@ -157,6 +168,8 @@ public:
public:
/// Return the operation containing this problem, e.g. to emit diagnostics.
Operation *getContainingOp() { return containingOp; }
/// Set the operation containing this problem, e.g. to emit diagnostics.
void setContainingOp(Operation *op) { containingOp = op; }
/// Return true if \p op is part of this problem.
bool hasOperation(Operation *op) { return operations.contains(op); }
@ -235,13 +248,13 @@ public:
/// interval, in which the execution of multiple iterations/samples/etc. may
/// overlap.
class CyclicProblem : public virtual Problem {
DEFINE_FACTORY_METHOD(CyclicProblem)
private:
DependenceProperty<unsigned> distance;
ProblemProperty<unsigned> initiationInterval;
public:
explicit CyclicProblem(Operation *containingOp) : Problem(containingOp) {}
/// The distance determines whether a dependence has to be satisfied in the
/// same iteration (distance=0 or not set), or distance-many iterations later.
Optional<unsigned> getDistance(Dependence dep) {
@ -280,13 +293,12 @@ public:
/// exceed the operator type's limit. These constraints do not apply to operator
/// types without a limit (not set, or 0).
class SharedOperatorsProblem : public virtual Problem {
DEFINE_FACTORY_METHOD(SharedOperatorsProblem)
private:
OperatorTypeProperty<unsigned> limit;
public:
explicit SharedOperatorsProblem(Operation *containingOp)
: Problem(containingOp) {}
/// The limit is the maximum number of operations using \p opr that are
/// allowed to start in the same time step.
Optional<unsigned> getLimit(OperatorType opr) { return limit.lookup(opr); }
@ -315,10 +327,7 @@ public:
/// not exceed the operator type's limit.
class ModuloProblem : public virtual CyclicProblem,
public virtual SharedOperatorsProblem {
public:
explicit ModuloProblem(Operation *containingOp)
: Problem(containingOp), CyclicProblem(containingOp),
SharedOperatorsProblem(containingOp) {}
DEFINE_FACTORY_METHOD(ModuloProblem)
protected:
/// \p opr is not oversubscribed in any congruence class modulo II.
@ -331,4 +340,6 @@ public:
} // namespace scheduling
} // namespace circt
#undef DEFINE_FACTORY_METHOD
#endif // CIRCT_SCHEDULING_PROBLEMS_H

View File

@ -43,7 +43,7 @@ circt::analysis::CyclicSchedulingAnalysis::CyclicSchedulingAnalysis(
void circt::analysis::CyclicSchedulingAnalysis::analyzeForOp(
AffineForOp forOp, MemoryDependenceAnalysis memoryAnalysis) {
// Create a cyclic scheduling problem.
CyclicProblem problem(forOp);
CyclicProblem problem = CyclicProblem::get(forOp);
// Insert memory dependences into the problem.
forOp.getBody()->walk([&](Operation *op) {

View File

@ -137,7 +137,7 @@ struct TestProblemPass : public PassWrapper<TestProblemPass, FunctionPass> {
void TestProblemPass::runOnFunction() {
auto func = getFunction();
Problem prob(func);
auto prob = Problem::get(func);
constructProblem(prob, func);
if (failed(prob.check())) {
@ -174,7 +174,7 @@ struct TestCyclicProblemPass
void TestCyclicProblemPass::runOnFunction() {
auto func = getFunction();
CyclicProblem prob(func);
auto prob = CyclicProblem::get(func);
constructProblem(prob, func);
constructCyclicProblem(prob, func);
@ -219,7 +219,7 @@ struct TestSharedOperatorsProblemPass
void TestSharedOperatorsProblemPass::runOnFunction() {
auto func = getFunction();
SharedOperatorsProblem prob(func);
auto prob = SharedOperatorsProblem::get(func);
constructProblem(prob, func);
constructSharedOperatorsProblem(prob, func);
@ -257,7 +257,7 @@ struct TestModuloProblemPass
void TestModuloProblemPass::runOnFunction() {
auto func = getFunction();
ModuloProblem prob(func);
auto prob = ModuloProblem::get(func);
constructProblem(prob, func);
constructCyclicProblem(prob, func);
constructSharedOperatorsProblem(prob, func);
@ -300,7 +300,7 @@ struct TestASAPSchedulerPass
void TestASAPSchedulerPass::runOnFunction() {
auto func = getFunction();
Problem prob(func);
auto prob = Problem::get(func);
constructProblem(prob, func);
assert(succeeded(prob.check()));
@ -342,7 +342,7 @@ void TestSimplexSchedulerPass::runOnFunction() {
OpBuilder builder(func.getContext());
if (problemToTest == "Problem") {
Problem prob(func);
auto prob = Problem::get(func);
constructProblem(prob, func);
assert(succeeded(prob.check()));
@ -361,7 +361,7 @@ void TestSimplexSchedulerPass::runOnFunction() {
}
if (problemToTest == "CyclicProblem") {
CyclicProblem prob(func);
auto prob = CyclicProblem::get(func);
constructProblem(prob, func);
constructCyclicProblem(prob, func);
assert(succeeded(prob.check()));
@ -383,7 +383,7 @@ void TestSimplexSchedulerPass::runOnFunction() {
}
if (problemToTest == "SharedOperatorsProblem") {
SharedOperatorsProblem prob(func);
auto prob = SharedOperatorsProblem::get(func);
constructProblem(prob, func);
constructSharedOperatorsProblem(prob, func);
assert(succeeded(prob.check()));