diff --git a/frontends/PyCDE/src/pycde/module.py b/frontends/PyCDE/src/pycde/module.py index 0d884fd33c..bc9646ec99 100644 --- a/frontends/PyCDE/src/pycde/module.py +++ b/frontends/PyCDE/src/pycde/module.py @@ -64,7 +64,8 @@ class module: # If it's just a module class, we should wrap it immediately self.mod = _module_base(func_or_class) _register_generator(self.mod.__name__, "extern_instantiate", - self._instantiate) + self._instantiate, + mlir.ir.DictAttr.get(self.mod._parameters)) return elif not inspect.isfunction(func_or_class): raise TypeError("@module got invalid object") @@ -99,7 +100,8 @@ class module: if self.extern_name: _register_generator(cls.__name__, "extern_instantiate", - self._instantiate) + self._instantiate, + mlir.ir.DictAttr.get(mod._parameters)) return mod return self.mod(*args, **kwargs) @@ -275,23 +277,23 @@ def _module_base(cls, params={}): cls._dont_touch.add(name) mod._output_ports_lookup = dict(mod._output_ports) - _register_generators(mod) + _register_generators(mod, mlir.ir.DictAttr.get(mod._parameters)) return mod -def _register_generators(modcls): +def _register_generators(modcls, parameters: mlir.ir.Attribute): """Scan the class, looking for and registering _Generators.""" for name in dir(modcls): member = getattr(modcls, name) if isinstance(member, _Generate): member.modcls = modcls - _register_generator(modcls.__name__, name, member) + _register_generator(modcls.__name__, name, member, parameters) -def _register_generator(class_name, generator_name, generator): +def _register_generator(class_name, generator_name, generator, parameters): circt.msft.register_generator(mlir.ir.Context.current, OPERATION_NAMESPACE + class_name, - generator_name, generator) + generator_name, generator, parameters) class _Generate: diff --git a/frontends/PyCDE/test/polynomial.py b/frontends/PyCDE/test/polynomial.py index 3409ab9ea0..326a3672b0 100755 --- a/frontends/PyCDE/test/polynomial.py +++ b/frontends/PyCDE/test/polynomial.py @@ -104,8 +104,16 @@ poly.print() # CHECK: %example2.y = hw.instance "example2" @PolyComputeForCoeff_62_42_6(%example.y) {parameters = {}} : (i32) -> i32 # CHECK: %example2.y_0 = hw.instance "example2" @PolyComputeForCoeff_1_2_3_4_5(%example.y) {parameters = {}} : (i32) -> i32 # CHECK: %pycde.CoolPolynomialCompute.y = hw.instance "pycde.CoolPolynomialCompute" @supercooldevice(%c23_i32) {coefficients = [4, 42], parameters = {}} : (i32) -> i32 -# CHECK: hw.module @PolyComputeForCoeff_62_42_6(%x: i32) -> (%y: i32) -# CHECK: hw.module @PolyComputeForCoeff_1_2_3_4_5(%x: i32) -> (%y: i32) +# CHECK-LABEL: hw.module @PolyComputeForCoeff_62_42_6(%x: i32) -> (%y: i32) +# CHECK: hw.constant 62 +# CHECK: hw.constant 42 +# CHECK: hw.constant 6 +# CHECK-LABEL: hw.module @PolyComputeForCoeff_1_2_3_4_5(%x: i32) -> (%y: i32) +# CHECK: hw.constant 1 +# CHECK: hw.constant 2 +# CHECK: hw.constant 3 +# CHECK: hw.constant 4 +# CHECK: hw.constant 5 # CHECK-NOT: hw.module @pycde.PolynomialCompute print("\n\n=== Verilog ===") diff --git a/include/circt-c/Dialect/MSFT.h b/include/circt-c/Dialect/MSFT.h index a1afbd13c0..192bd854ec 100644 --- a/include/circt-c/Dialect/MSFT.h +++ b/include/circt-c/Dialect/MSFT.h @@ -36,7 +36,8 @@ typedef struct { /// Register a generator callback (function pointer, user data pointer). void mlirMSFTRegisterGenerator(MlirContext, const char *opName, const char *generatorName, - mlirMSFTGeneratorCallback cb); + mlirMSFTGeneratorCallback cb, + MlirAttribute parameters); #ifdef __cplusplus } diff --git a/include/circt/Dialect/MSFT/MSFT.td b/include/circt/Dialect/MSFT/MSFT.td index 1fcf1358dd..80b3a2851f 100644 --- a/include/circt/Dialect/MSFT/MSFT.td +++ b/include/circt/Dialect/MSFT/MSFT.td @@ -34,7 +34,7 @@ def MSFTDialect : Dialect { void registerAttributes(); void registerGenerator(StringRef opName, StringRef generatorName, - GeneratorCallback cb); + GeneratorCallback cb, Attribute parameters); /// Generator details don't need to be exposed. llvm::ManagedStatic generators; diff --git a/lib/Bindings/Python/MSFTModule.cpp b/lib/Bindings/Python/MSFTModule.cpp index 5df0f2b5fd..0e54983161 100644 --- a/lib/Bindings/Python/MSFTModule.cpp +++ b/lib/Bindings/Python/MSFTModule.cpp @@ -49,11 +49,13 @@ static MlirOperation callPyFunc(MlirOperation op, void *userData) { } static void registerGenerator(MlirContext ctxt, std::string opName, - std::string generatorName, py::function cb) { + std::string generatorName, py::function cb, + MlirAttribute parameters) { // Since we don't have an 'unregister' call, just allocate in forget about it. py::function *cbPtr = new py::function(cb); mlirMSFTRegisterGenerator(ctxt, opName.c_str(), generatorName.c_str(), - mlirMSFTGeneratorCallback{&callPyFunc, cbPtr}); + mlirMSFTGeneratorCallback{&callPyFunc, cbPtr}, + parameters); } /// Populate the msft python module. diff --git a/lib/CAPI/Dialect/MSFT.cpp b/lib/CAPI/Dialect/MSFT.cpp index 25e35c0deb..d83510bc38 100644 --- a/lib/CAPI/Dialect/MSFT.cpp +++ b/lib/CAPI/Dialect/MSFT.cpp @@ -25,12 +25,14 @@ MlirLogicalResult mlirMSFTExportTcl(MlirModule module, void mlirMSFTRegisterGenerator(MlirContext cCtxt, const char *opName, const char *generatorName, - mlirMSFTGeneratorCallback cb) { + mlirMSFTGeneratorCallback cb, + MlirAttribute parameters) { mlir::MLIRContext *ctxt = unwrap(cCtxt); MSFTDialect *msft = ctxt->getLoadedDialect(); - msft->registerGenerator(llvm::StringRef(opName), - llvm::StringRef(generatorName), - [cb](mlir::Operation *op) { - return unwrap(cb.callback(wrap(op), cb.userData)); - }); + msft->registerGenerator( + llvm::StringRef(opName), llvm::StringRef(generatorName), + [cb](mlir::Operation *op) { + return unwrap(cb.callback(wrap(op), cb.userData)); + }, + unwrap(parameters)); } diff --git a/lib/Dialect/MSFT/MSFTGenerator.cpp b/lib/Dialect/MSFT/MSFTGenerator.cpp index 066b3ce7a4..0981561c8b 100644 --- a/lib/Dialect/MSFT/MSFTGenerator.cpp +++ b/lib/Dialect/MSFT/MSFTGenerator.cpp @@ -12,6 +12,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringMap.h" @@ -23,11 +24,12 @@ using GeneratorSet = llvm::SmallSet; namespace { /// Holds the set of registered generators for each operation. class OpGenerator { - llvm::StringMap generators; + llvm::StringMap> generators; public: - void registerOpGenerator(StringRef generatorName, GeneratorCallback cb) { - generators[generatorName] = cb; + void registerOpGenerator(StringRef generatorName, Attribute parameters, + GeneratorCallback cb) { + generators[generatorName][parameters] = cb; } LogicalResult runOnOperation(mlir::Operation *op, GeneratorSet generatorSet); @@ -45,21 +47,28 @@ LogicalResult OpGenerator::runOnOperation(mlir::Operation *op, // Check if any of the generators were selected in the generator set. If more // than one candidate is present in the generator set, raise an error. GeneratorCallback gen; + Attribute parameters = op->getAttr("parameters"); for (auto &generatorPair : generators) { if (generatorSet.contains(generatorPair.first())) { if (gen) return op->emitError("multiple generators selected"); - gen = generatorPair.second; + auto callbackPair = generatorPair.second.find(parameters); + if (callbackPair != generatorPair.second.end()) + gen = callbackPair->second; } } // If no generator was selected by the generator set, and there is just one // generator, default to using that. Otherwise raise an error. if (!gen) { - if (generators.size() == 1) - gen = generators.begin()->second; - else + if (generators.size() == 1) { + auto generatorMap = generators.begin()->second; + auto callbackPair = generatorMap.find(parameters); + if (callbackPair != generatorMap.end()) + gen = callbackPair->second; + } else { return op->emitError("unable to select a generator"); + } } mlir::IRRewriter rewriter(op->getContext()); @@ -89,9 +98,10 @@ struct Generators { } // namespace circt void MSFTDialect::registerGenerator(StringRef opName, StringRef generatorName, - GeneratorCallback cb) { - generators->registeredOpGenerators[opName].registerOpGenerator(generatorName, - cb); + GeneratorCallback cb, + Attribute parameters) { + generators->registeredOpGenerators[opName].registerOpGenerator( + generatorName, parameters, cb); } namespace circt {