diff --git a/include/circt/Scheduling/Problems.h b/include/circt/Scheduling/Problems.h index 9f4b1562f8..8d6f149423 100644 --- a/include/circt/Scheduling/Problems.h +++ b/include/circt/Scheduling/Problems.h @@ -26,6 +26,17 @@ #include "llvm/ADT/Optional.h" #include "llvm/ADT/SetVector.h" +#define DEFINE_FACTORY_METHOD(ProblemClass) \ +protected: \ + ProblemClass() {} \ + \ +public: \ + static ProblemClass get(Operation *containingOp) { \ + ProblemClass prob; \ + prob.setContainingOp(containingOp); \ + return prob; \ + } + namespace circt { namespace scheduling { @@ -73,9 +84,9 @@ namespace scheduling { /// for each registered operation, and the precedence constraints as modeled by /// the dependences are satisfied. class Problem { + DEFINE_FACTORY_METHOD(Problem) + public: - /// Initialize a scheduling problem corresponding to \p containingOp. - explicit Problem(Operation *containingOp) : containingOp(containingOp) {} virtual ~Problem() = default; friend detail::DependenceIterator; @@ -157,6 +168,8 @@ public: public: /// Return the operation containing this problem, e.g. to emit diagnostics. Operation *getContainingOp() { return containingOp; } + /// Set the operation containing this problem, e.g. to emit diagnostics. + void setContainingOp(Operation *op) { containingOp = op; } /// Return true if \p op is part of this problem. bool hasOperation(Operation *op) { return operations.contains(op); } @@ -235,13 +248,13 @@ public: /// interval, in which the execution of multiple iterations/samples/etc. may /// overlap. class CyclicProblem : public virtual Problem { + DEFINE_FACTORY_METHOD(CyclicProblem) + private: DependenceProperty distance; ProblemProperty initiationInterval; public: - explicit CyclicProblem(Operation *containingOp) : Problem(containingOp) {} - /// The distance determines whether a dependence has to be satisfied in the /// same iteration (distance=0 or not set), or distance-many iterations later. Optional getDistance(Dependence dep) { @@ -280,13 +293,12 @@ public: /// exceed the operator type's limit. These constraints do not apply to operator /// types without a limit (not set, or 0). class SharedOperatorsProblem : public virtual Problem { + DEFINE_FACTORY_METHOD(SharedOperatorsProblem) + private: OperatorTypeProperty limit; public: - explicit SharedOperatorsProblem(Operation *containingOp) - : Problem(containingOp) {} - /// The limit is the maximum number of operations using \p opr that are /// allowed to start in the same time step. Optional getLimit(OperatorType opr) { return limit.lookup(opr); } @@ -315,10 +327,7 @@ public: /// not exceed the operator type's limit. class ModuloProblem : public virtual CyclicProblem, public virtual SharedOperatorsProblem { -public: - explicit ModuloProblem(Operation *containingOp) - : Problem(containingOp), CyclicProblem(containingOp), - SharedOperatorsProblem(containingOp) {} + DEFINE_FACTORY_METHOD(ModuloProblem) protected: /// \p opr is not oversubscribed in any congruence class modulo II. @@ -331,4 +340,6 @@ public: } // namespace scheduling } // namespace circt +#undef DEFINE_FACTORY_METHOD + #endif // CIRCT_SCHEDULING_PROBLEMS_H diff --git a/lib/Analysis/SchedulingAnalysis.cpp b/lib/Analysis/SchedulingAnalysis.cpp index 7c7e7bc9bd..59938b7ad3 100644 --- a/lib/Analysis/SchedulingAnalysis.cpp +++ b/lib/Analysis/SchedulingAnalysis.cpp @@ -43,7 +43,7 @@ circt::analysis::CyclicSchedulingAnalysis::CyclicSchedulingAnalysis( void circt::analysis::CyclicSchedulingAnalysis::analyzeForOp( AffineForOp forOp, MemoryDependenceAnalysis memoryAnalysis) { // Create a cyclic scheduling problem. - CyclicProblem problem(forOp); + CyclicProblem problem = CyclicProblem::get(forOp); // Insert memory dependences into the problem. forOp.getBody()->walk([&](Operation *op) { diff --git a/lib/Scheduling/TestPasses.cpp b/lib/Scheduling/TestPasses.cpp index fb171ce76f..fe24b9c097 100644 --- a/lib/Scheduling/TestPasses.cpp +++ b/lib/Scheduling/TestPasses.cpp @@ -137,7 +137,7 @@ struct TestProblemPass : public PassWrapper { void TestProblemPass::runOnFunction() { auto func = getFunction(); - Problem prob(func); + auto prob = Problem::get(func); constructProblem(prob, func); if (failed(prob.check())) { @@ -174,7 +174,7 @@ struct TestCyclicProblemPass void TestCyclicProblemPass::runOnFunction() { auto func = getFunction(); - CyclicProblem prob(func); + auto prob = CyclicProblem::get(func); constructProblem(prob, func); constructCyclicProblem(prob, func); @@ -219,7 +219,7 @@ struct TestSharedOperatorsProblemPass void TestSharedOperatorsProblemPass::runOnFunction() { auto func = getFunction(); - SharedOperatorsProblem prob(func); + auto prob = SharedOperatorsProblem::get(func); constructProblem(prob, func); constructSharedOperatorsProblem(prob, func); @@ -257,7 +257,7 @@ struct TestModuloProblemPass void TestModuloProblemPass::runOnFunction() { auto func = getFunction(); - ModuloProblem prob(func); + auto prob = ModuloProblem::get(func); constructProblem(prob, func); constructCyclicProblem(prob, func); constructSharedOperatorsProblem(prob, func); @@ -300,7 +300,7 @@ struct TestASAPSchedulerPass void TestASAPSchedulerPass::runOnFunction() { auto func = getFunction(); - Problem prob(func); + auto prob = Problem::get(func); constructProblem(prob, func); assert(succeeded(prob.check())); @@ -342,7 +342,7 @@ void TestSimplexSchedulerPass::runOnFunction() { OpBuilder builder(func.getContext()); if (problemToTest == "Problem") { - Problem prob(func); + auto prob = Problem::get(func); constructProblem(prob, func); assert(succeeded(prob.check())); @@ -361,7 +361,7 @@ void TestSimplexSchedulerPass::runOnFunction() { } if (problemToTest == "CyclicProblem") { - CyclicProblem prob(func); + auto prob = CyclicProblem::get(func); constructProblem(prob, func); constructCyclicProblem(prob, func); assert(succeeded(prob.check())); @@ -383,7 +383,7 @@ void TestSimplexSchedulerPass::runOnFunction() { } if (problemToTest == "SharedOperatorsProblem") { - SharedOperatorsProblem prob(func); + auto prob = SharedOperatorsProblem::get(func); constructProblem(prob, func); constructSharedOperatorsProblem(prob, func); assert(succeeded(prob.check()));