mirror of https://github.com/llvm/circt.git
[PyCDE] Include parameters in generator registration and lookup. (#1298)
Previously, different parameterizations of the same module would end up using the most-recently registered generator and the parameterization used when it was registered. This includes the parameterization in the registered mapping, and looks for that same parameterization when looking up generators. This is trivially supported since we already store the parameterization as an attribute on the Operation. Added a few lines to the polynomial integration test to ensure the behavior.
This commit is contained in:
parent
99cc3a854d
commit
9dd73b174f
|
@ -64,7 +64,8 @@ class module:
|
|||
# If it's just a module class, we should wrap it immediately
|
||||
self.mod = _module_base(func_or_class)
|
||||
_register_generator(self.mod.__name__, "extern_instantiate",
|
||||
self._instantiate)
|
||||
self._instantiate,
|
||||
mlir.ir.DictAttr.get(self.mod._parameters))
|
||||
return
|
||||
elif not inspect.isfunction(func_or_class):
|
||||
raise TypeError("@module got invalid object")
|
||||
|
@ -99,7 +100,8 @@ class module:
|
|||
|
||||
if self.extern_name:
|
||||
_register_generator(cls.__name__, "extern_instantiate",
|
||||
self._instantiate)
|
||||
self._instantiate,
|
||||
mlir.ir.DictAttr.get(mod._parameters))
|
||||
return mod
|
||||
|
||||
return self.mod(*args, **kwargs)
|
||||
|
@ -275,23 +277,23 @@ def _module_base(cls, params={}):
|
|||
cls._dont_touch.add(name)
|
||||
mod._output_ports_lookup = dict(mod._output_ports)
|
||||
|
||||
_register_generators(mod)
|
||||
_register_generators(mod, mlir.ir.DictAttr.get(mod._parameters))
|
||||
return mod
|
||||
|
||||
|
||||
def _register_generators(modcls):
|
||||
def _register_generators(modcls, parameters: mlir.ir.Attribute):
|
||||
"""Scan the class, looking for and registering _Generators."""
|
||||
for name in dir(modcls):
|
||||
member = getattr(modcls, name)
|
||||
if isinstance(member, _Generate):
|
||||
member.modcls = modcls
|
||||
_register_generator(modcls.__name__, name, member)
|
||||
_register_generator(modcls.__name__, name, member, parameters)
|
||||
|
||||
|
||||
def _register_generator(class_name, generator_name, generator):
|
||||
def _register_generator(class_name, generator_name, generator, parameters):
|
||||
circt.msft.register_generator(mlir.ir.Context.current,
|
||||
OPERATION_NAMESPACE + class_name,
|
||||
generator_name, generator)
|
||||
generator_name, generator, parameters)
|
||||
|
||||
|
||||
class _Generate:
|
||||
|
|
|
@ -104,8 +104,16 @@ poly.print()
|
|||
# CHECK: %example2.y = hw.instance "example2" @PolyComputeForCoeff_62_42_6(%example.y) {parameters = {}} : (i32) -> i32
|
||||
# CHECK: %example2.y_0 = hw.instance "example2" @PolyComputeForCoeff_1_2_3_4_5(%example.y) {parameters = {}} : (i32) -> i32
|
||||
# CHECK: %pycde.CoolPolynomialCompute.y = hw.instance "pycde.CoolPolynomialCompute" @supercooldevice(%c23_i32) {coefficients = [4, 42], parameters = {}} : (i32) -> i32
|
||||
# CHECK: hw.module @PolyComputeForCoeff_62_42_6(%x: i32) -> (%y: i32)
|
||||
# CHECK: hw.module @PolyComputeForCoeff_1_2_3_4_5(%x: i32) -> (%y: i32)
|
||||
# CHECK-LABEL: hw.module @PolyComputeForCoeff_62_42_6(%x: i32) -> (%y: i32)
|
||||
# CHECK: hw.constant 62
|
||||
# CHECK: hw.constant 42
|
||||
# CHECK: hw.constant 6
|
||||
# CHECK-LABEL: hw.module @PolyComputeForCoeff_1_2_3_4_5(%x: i32) -> (%y: i32)
|
||||
# CHECK: hw.constant 1
|
||||
# CHECK: hw.constant 2
|
||||
# CHECK: hw.constant 3
|
||||
# CHECK: hw.constant 4
|
||||
# CHECK: hw.constant 5
|
||||
# CHECK-NOT: hw.module @pycde.PolynomialCompute
|
||||
|
||||
print("\n\n=== Verilog ===")
|
||||
|
|
|
@ -36,7 +36,8 @@ typedef struct {
|
|||
/// Register a generator callback (function pointer, user data pointer).
|
||||
void mlirMSFTRegisterGenerator(MlirContext, const char *opName,
|
||||
const char *generatorName,
|
||||
mlirMSFTGeneratorCallback cb);
|
||||
mlirMSFTGeneratorCallback cb,
|
||||
MlirAttribute parameters);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ def MSFTDialect : Dialect {
|
|||
void registerAttributes();
|
||||
|
||||
void registerGenerator(StringRef opName, StringRef generatorName,
|
||||
GeneratorCallback cb);
|
||||
GeneratorCallback cb, Attribute parameters);
|
||||
|
||||
/// Generator details don't need to be exposed.
|
||||
llvm::ManagedStatic<detail::Generators> generators;
|
||||
|
|
|
@ -49,11 +49,13 @@ static MlirOperation callPyFunc(MlirOperation op, void *userData) {
|
|||
}
|
||||
|
||||
static void registerGenerator(MlirContext ctxt, std::string opName,
|
||||
std::string generatorName, py::function cb) {
|
||||
std::string generatorName, py::function cb,
|
||||
MlirAttribute parameters) {
|
||||
// Since we don't have an 'unregister' call, just allocate in forget about it.
|
||||
py::function *cbPtr = new py::function(cb);
|
||||
mlirMSFTRegisterGenerator(ctxt, opName.c_str(), generatorName.c_str(),
|
||||
mlirMSFTGeneratorCallback{&callPyFunc, cbPtr});
|
||||
mlirMSFTGeneratorCallback{&callPyFunc, cbPtr},
|
||||
parameters);
|
||||
}
|
||||
|
||||
/// Populate the msft python module.
|
||||
|
|
|
@ -25,12 +25,14 @@ MlirLogicalResult mlirMSFTExportTcl(MlirModule module,
|
|||
|
||||
void mlirMSFTRegisterGenerator(MlirContext cCtxt, const char *opName,
|
||||
const char *generatorName,
|
||||
mlirMSFTGeneratorCallback cb) {
|
||||
mlirMSFTGeneratorCallback cb,
|
||||
MlirAttribute parameters) {
|
||||
mlir::MLIRContext *ctxt = unwrap(cCtxt);
|
||||
MSFTDialect *msft = ctxt->getLoadedDialect<MSFTDialect>();
|
||||
msft->registerGenerator(llvm::StringRef(opName),
|
||||
llvm::StringRef(generatorName),
|
||||
msft->registerGenerator(
|
||||
llvm::StringRef(opName), llvm::StringRef(generatorName),
|
||||
[cb](mlir::Operation *op) {
|
||||
return unwrap(cb.callback(wrap(op), cb.userData));
|
||||
});
|
||||
},
|
||||
unwrap(parameters));
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
|
@ -23,11 +24,12 @@ using GeneratorSet = llvm::SmallSet<StringRef, 8>;
|
|||
namespace {
|
||||
/// Holds the set of registered generators for each operation.
|
||||
class OpGenerator {
|
||||
llvm::StringMap<GeneratorCallback> generators;
|
||||
llvm::StringMap<llvm::DenseMap<Attribute, GeneratorCallback>> generators;
|
||||
|
||||
public:
|
||||
void registerOpGenerator(StringRef generatorName, GeneratorCallback cb) {
|
||||
generators[generatorName] = cb;
|
||||
void registerOpGenerator(StringRef generatorName, Attribute parameters,
|
||||
GeneratorCallback cb) {
|
||||
generators[generatorName][parameters] = cb;
|
||||
}
|
||||
|
||||
LogicalResult runOnOperation(mlir::Operation *op, GeneratorSet generatorSet);
|
||||
|
@ -45,22 +47,29 @@ LogicalResult OpGenerator::runOnOperation(mlir::Operation *op,
|
|||
// Check if any of the generators were selected in the generator set. If more
|
||||
// than one candidate is present in the generator set, raise an error.
|
||||
GeneratorCallback gen;
|
||||
Attribute parameters = op->getAttr("parameters");
|
||||
for (auto &generatorPair : generators) {
|
||||
if (generatorSet.contains(generatorPair.first())) {
|
||||
if (gen)
|
||||
return op->emitError("multiple generators selected");
|
||||
gen = generatorPair.second;
|
||||
auto callbackPair = generatorPair.second.find(parameters);
|
||||
if (callbackPair != generatorPair.second.end())
|
||||
gen = callbackPair->second;
|
||||
}
|
||||
}
|
||||
|
||||
// If no generator was selected by the generator set, and there is just one
|
||||
// generator, default to using that. Otherwise raise an error.
|
||||
if (!gen) {
|
||||
if (generators.size() == 1)
|
||||
gen = generators.begin()->second;
|
||||
else
|
||||
if (generators.size() == 1) {
|
||||
auto generatorMap = generators.begin()->second;
|
||||
auto callbackPair = generatorMap.find(parameters);
|
||||
if (callbackPair != generatorMap.end())
|
||||
gen = callbackPair->second;
|
||||
} else {
|
||||
return op->emitError("unable to select a generator");
|
||||
}
|
||||
}
|
||||
|
||||
mlir::IRRewriter rewriter(op->getContext());
|
||||
Operation *replacement = gen(op);
|
||||
|
@ -89,9 +98,10 @@ struct Generators {
|
|||
} // namespace circt
|
||||
|
||||
void MSFTDialect::registerGenerator(StringRef opName, StringRef generatorName,
|
||||
GeneratorCallback cb) {
|
||||
generators->registeredOpGenerators[opName].registerOpGenerator(generatorName,
|
||||
cb);
|
||||
GeneratorCallback cb,
|
||||
Attribute parameters) {
|
||||
generators->registeredOpGenerators[opName].registerOpGenerator(
|
||||
generatorName, parameters, cb);
|
||||
}
|
||||
|
||||
namespace circt {
|
||||
|
|
Loading…
Reference in New Issue