[PyCDE] Include parameters in generator registration and lookup. (#1298)

Previously, different parameterizations of the same module would end
up using the most-recently registered generator and the
parameterization used when it was registered.

This includes the parameterization in the registered mapping, and
looks for that same parameterization when looking up generators. This
is trivially supported since we already store the parameterization as
an attribute on the Operation.

Added a few lines to the polynomial integration test to ensure the
behavior.
This commit is contained in:
mikeurbach 2021-06-17 19:59:48 -06:00 committed by GitHub
parent 99cc3a854d
commit 9dd73b174f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 54 additions and 29 deletions

View File

@ -64,7 +64,8 @@ class module:
# If it's just a module class, we should wrap it immediately
self.mod = _module_base(func_or_class)
_register_generator(self.mod.__name__, "extern_instantiate",
self._instantiate)
self._instantiate,
mlir.ir.DictAttr.get(self.mod._parameters))
return
elif not inspect.isfunction(func_or_class):
raise TypeError("@module got invalid object")
@ -99,7 +100,8 @@ class module:
if self.extern_name:
_register_generator(cls.__name__, "extern_instantiate",
self._instantiate)
self._instantiate,
mlir.ir.DictAttr.get(mod._parameters))
return mod
return self.mod(*args, **kwargs)
@ -275,23 +277,23 @@ def _module_base(cls, params={}):
cls._dont_touch.add(name)
mod._output_ports_lookup = dict(mod._output_ports)
_register_generators(mod)
_register_generators(mod, mlir.ir.DictAttr.get(mod._parameters))
return mod
def _register_generators(modcls):
def _register_generators(modcls, parameters: mlir.ir.Attribute):
"""Scan the class, looking for and registering _Generators."""
for name in dir(modcls):
member = getattr(modcls, name)
if isinstance(member, _Generate):
member.modcls = modcls
_register_generator(modcls.__name__, name, member)
_register_generator(modcls.__name__, name, member, parameters)
def _register_generator(class_name, generator_name, generator):
def _register_generator(class_name, generator_name, generator, parameters):
circt.msft.register_generator(mlir.ir.Context.current,
OPERATION_NAMESPACE + class_name,
generator_name, generator)
generator_name, generator, parameters)
class _Generate:

View File

@ -104,8 +104,16 @@ poly.print()
# CHECK: %example2.y = hw.instance "example2" @PolyComputeForCoeff_62_42_6(%example.y) {parameters = {}} : (i32) -> i32
# CHECK: %example2.y_0 = hw.instance "example2" @PolyComputeForCoeff_1_2_3_4_5(%example.y) {parameters = {}} : (i32) -> i32
# CHECK: %pycde.CoolPolynomialCompute.y = hw.instance "pycde.CoolPolynomialCompute" @supercooldevice(%c23_i32) {coefficients = [4, 42], parameters = {}} : (i32) -> i32
# CHECK: hw.module @PolyComputeForCoeff_62_42_6(%x: i32) -> (%y: i32)
# CHECK: hw.module @PolyComputeForCoeff_1_2_3_4_5(%x: i32) -> (%y: i32)
# CHECK-LABEL: hw.module @PolyComputeForCoeff_62_42_6(%x: i32) -> (%y: i32)
# CHECK: hw.constant 62
# CHECK: hw.constant 42
# CHECK: hw.constant 6
# CHECK-LABEL: hw.module @PolyComputeForCoeff_1_2_3_4_5(%x: i32) -> (%y: i32)
# CHECK: hw.constant 1
# CHECK: hw.constant 2
# CHECK: hw.constant 3
# CHECK: hw.constant 4
# CHECK: hw.constant 5
# CHECK-NOT: hw.module @pycde.PolynomialCompute
print("\n\n=== Verilog ===")

View File

@ -36,7 +36,8 @@ typedef struct {
/// Register a generator callback (function pointer, user data pointer).
void mlirMSFTRegisterGenerator(MlirContext, const char *opName,
const char *generatorName,
mlirMSFTGeneratorCallback cb);
mlirMSFTGeneratorCallback cb,
MlirAttribute parameters);
#ifdef __cplusplus
}

View File

@ -34,7 +34,7 @@ def MSFTDialect : Dialect {
void registerAttributes();
void registerGenerator(StringRef opName, StringRef generatorName,
GeneratorCallback cb);
GeneratorCallback cb, Attribute parameters);
/// Generator details don't need to be exposed.
llvm::ManagedStatic<detail::Generators> generators;

View File

@ -49,11 +49,13 @@ static MlirOperation callPyFunc(MlirOperation op, void *userData) {
}
static void registerGenerator(MlirContext ctxt, std::string opName,
std::string generatorName, py::function cb) {
std::string generatorName, py::function cb,
MlirAttribute parameters) {
// Since we don't have an 'unregister' call, just allocate in forget about it.
py::function *cbPtr = new py::function(cb);
mlirMSFTRegisterGenerator(ctxt, opName.c_str(), generatorName.c_str(),
mlirMSFTGeneratorCallback{&callPyFunc, cbPtr});
mlirMSFTGeneratorCallback{&callPyFunc, cbPtr},
parameters);
}
/// Populate the msft python module.

View File

@ -25,12 +25,14 @@ MlirLogicalResult mlirMSFTExportTcl(MlirModule module,
void mlirMSFTRegisterGenerator(MlirContext cCtxt, const char *opName,
const char *generatorName,
mlirMSFTGeneratorCallback cb) {
mlirMSFTGeneratorCallback cb,
MlirAttribute parameters) {
mlir::MLIRContext *ctxt = unwrap(cCtxt);
MSFTDialect *msft = ctxt->getLoadedDialect<MSFTDialect>();
msft->registerGenerator(llvm::StringRef(opName),
llvm::StringRef(generatorName),
[cb](mlir::Operation *op) {
return unwrap(cb.callback(wrap(op), cb.userData));
});
msft->registerGenerator(
llvm::StringRef(opName), llvm::StringRef(generatorName),
[cb](mlir::Operation *op) {
return unwrap(cb.callback(wrap(op), cb.userData));
},
unwrap(parameters));
}

View File

@ -12,6 +12,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringMap.h"
@ -23,11 +24,12 @@ using GeneratorSet = llvm::SmallSet<StringRef, 8>;
namespace {
/// Holds the set of registered generators for each operation.
class OpGenerator {
llvm::StringMap<GeneratorCallback> generators;
llvm::StringMap<llvm::DenseMap<Attribute, GeneratorCallback>> generators;
public:
void registerOpGenerator(StringRef generatorName, GeneratorCallback cb) {
generators[generatorName] = cb;
void registerOpGenerator(StringRef generatorName, Attribute parameters,
GeneratorCallback cb) {
generators[generatorName][parameters] = cb;
}
LogicalResult runOnOperation(mlir::Operation *op, GeneratorSet generatorSet);
@ -45,21 +47,28 @@ LogicalResult OpGenerator::runOnOperation(mlir::Operation *op,
// Check if any of the generators were selected in the generator set. If more
// than one candidate is present in the generator set, raise an error.
GeneratorCallback gen;
Attribute parameters = op->getAttr("parameters");
for (auto &generatorPair : generators) {
if (generatorSet.contains(generatorPair.first())) {
if (gen)
return op->emitError("multiple generators selected");
gen = generatorPair.second;
auto callbackPair = generatorPair.second.find(parameters);
if (callbackPair != generatorPair.second.end())
gen = callbackPair->second;
}
}
// If no generator was selected by the generator set, and there is just one
// generator, default to using that. Otherwise raise an error.
if (!gen) {
if (generators.size() == 1)
gen = generators.begin()->second;
else
if (generators.size() == 1) {
auto generatorMap = generators.begin()->second;
auto callbackPair = generatorMap.find(parameters);
if (callbackPair != generatorMap.end())
gen = callbackPair->second;
} else {
return op->emitError("unable to select a generator");
}
}
mlir::IRRewriter rewriter(op->getContext());
@ -89,9 +98,10 @@ struct Generators {
} // namespace circt
void MSFTDialect::registerGenerator(StringRef opName, StringRef generatorName,
GeneratorCallback cb) {
generators->registeredOpGenerators[opName].registerOpGenerator(generatorName,
cb);
GeneratorCallback cb,
Attribute parameters) {
generators->registeredOpGenerators[opName].registerOpGenerator(
generatorName, parameters, cb);
}
namespace circt {