[HW] Relax parameter evaluation to allow resolution to passed in parameters. (#3159)

This allows a parametric module to instantiate nested parametric modules using
its own parameters.

This change required reworking the HWSpecialize pass, where parametric modules
defer specialization of instantiated parametric modules until after the parent
modules themselves are specialized.
This commit is contained in:
Daniel Resnick 2022-05-23 08:37:16 -07:00 committed by GitHub
parent 1bc5e993a2
commit 498798cbc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 248 additions and 66 deletions

View File

@ -33,9 +33,9 @@ mlir::FailureOr<mlir::Type> evaluateParametricType(mlir::Location loc,
/// Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set
/// of provided parameter values.
mlir::FailureOr<mlir::APInt> evaluateParametricAttr(mlir::Location loc,
mlir::ArrayAttr parameters,
mlir::Attribute paramAttr);
mlir::FailureOr<mlir::Attribute>
evaluateParametricAttr(mlir::Location loc, mlir::ArrayAttr parameters,
mlir::Attribute paramAttr);
/// Returns true if any part of t is parametric.
bool isParametricType(mlir::Type t);

View File

@ -750,8 +750,9 @@ replaceDeclRefInExpr(Location loc,
return {};
}
FailureOr<APInt> hw::evaluateParametricAttr(Location loc, ArrayAttr parameters,
Attribute paramAttr) {
FailureOr<Attribute> hw::evaluateParametricAttr(Location loc,
ArrayAttr parameters,
Attribute paramAttr) {
// Create a map of the provided parameters for faster lookup.
std::map<std::string, Attribute> parameterMap;
for (auto param : parameters) {
@ -767,38 +768,36 @@ FailureOr<APInt> hw::evaluateParametricAttr(Location loc, ArrayAttr parameters,
paramAttr = paramAttrRes.getValue();
// Then, evaluate the parametric attribute.
if (auto intAttr = paramAttr.dyn_cast<IntegerAttr>())
return intAttr.getValue();
if (paramAttr.isa<IntegerAttr>() || paramAttr.isa<hw::ParamDeclRefAttr>())
return paramAttr;
if (auto paramExprAttr = paramAttr.dyn_cast<hw::ParamExprAttr>()) {
// Since any ParamDeclRefAttr was replaced within the expression, the
// expression should be able to be fully canonicalized to a constant. We do
// this through the existing ParamExprAttr canonicalizer.
auto resAttr = ParamExprAttr::get(paramExprAttr.getOpcode(),
paramExprAttr.getOperands());
auto resIntAttr = resAttr.dyn_cast<IntegerAttr>();
if (!resIntAttr)
return emitError(loc,
"Could not evaluate the expression to a constant value")
.attachNote()
<< "This means that some parts of the expression did not resolve "
"to a constant";
return resIntAttr.getValue();
// Since any ParamDeclRefAttr was replaced within the expression,
// we re-evaluate the expression through the existing ParamExprAttr
// canonicalizer.
return ParamExprAttr::get(paramExprAttr.getOpcode(),
paramExprAttr.getOperands());
}
llvm_unreachable("Unhandled parametric attribute");
return APInt();
return Attribute();
}
FailureOr<Type> hw::evaluateParametricType(Location loc, ArrayAttr parameters,
Type type) {
return llvm::TypeSwitch<Type, Type>(type)
.Case<hw::IntType>([&](hw::IntType t) -> FailureOr<Type> {
auto attrValue = evaluateParametricAttr(loc, parameters, t.getWidth());
if (failed(attrValue))
auto evaluatedWidth =
evaluateParametricAttr(loc, parameters, t.getWidth());
if (failed(evaluatedWidth))
return {failure()};
return {IntegerType::get(type.getContext(),
attrValue.getValue().getSExtValue())};
// If the width was evaluated to a constant, return an `IntegerType`
if (auto intAttr = evaluatedWidth->dyn_cast<IntegerAttr>())
return {IntegerType::get(type.getContext(),
intAttr.getValue().getSExtValue())};
// Otherwise parameter references are still involved
return hw::IntType::get(evaluatedWidth.getValue());
})
.Case<hw::ArrayType>([&](hw::ArrayType arrayType) -> FailureOr<Type> {
auto size =
@ -809,10 +808,18 @@ FailureOr<Type> hw::evaluateParametricType(Location loc, ArrayAttr parameters,
evaluateParametricType(loc, parameters, arrayType.getElementType());
if (failed(elementType))
return failure();
return hw::ArrayType::get(
arrayType.getContext(), elementType.getValue(),
IntegerAttr::get(IntegerType::get(type.getContext(), 64),
size.getValue().getSExtValue()));
// If the size was evaluated to a constant, use a 64-bit integer
// attribute version of it
if (auto intAttr = size->dyn_cast<IntegerAttr>())
return hw::ArrayType::get(
arrayType.getContext(), elementType.getValue(),
IntegerAttr::get(IntegerType::get(type.getContext(), 64),
intAttr.getValue().getSExtValue()));
// Otherwise parameter references are still involved
return hw::ArrayType::get(arrayType.getContext(),
elementType.getValue(), size.getValue());
})
.Default([&](auto) { return type; });
}

View File

@ -31,8 +31,6 @@ using namespace hw;
namespace {
using InstanceParameters = llvm::DenseMap<hw::HWModuleOp, ArrayAttr>;
// Generates a module name by composing the name of 'moduleOp' and the set of
// provided 'parameters'.
static std::string generateModuleName(SymbolCache &symbolCache,
@ -81,6 +79,37 @@ static FailureOr<Value> narrowValueToArrayWidth(OpBuilder &builder, Value array,
.getResult();
}
static hw::HWModuleOp targetModuleOp(hw::InstanceOp instanceOp,
const SymbolCache &sc) {
auto *targetOp = sc.getDefinition(instanceOp.moduleNameAttr());
auto targetHWModule = dyn_cast<hw::HWModuleOp>(targetOp);
if (!targetHWModule)
return {}; // Won't specialize external modules.
if (targetHWModule.parameters().size() == 0)
return {}; // nothing to record or specialize
return targetHWModule;
}
// Stores unique module parameters and references to them
struct ParameterSpecializationRegistry {
llvm::MapVector<hw::HWModuleOp, llvm::SetVector<ArrayAttr>>
uniqueModuleParameters;
bool isRegistered(hw::HWModuleOp moduleOp, ArrayAttr parameters,
const SymbolCache &sc) const {
auto it = uniqueModuleParameters.find(moduleOp);
return it != uniqueModuleParameters.end() &&
it->second.contains(parameters);
}
void registerModuleOp(hw::HWModuleOp moduleOp, ArrayAttr parameters,
SymbolCache &sc) {
uniqueModuleParameters[moduleOp].insert(parameters);
}
};
struct EliminateParamValueOpPattern : public OpRewritePattern<ParamValueOp> {
EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters)
: OpRewritePattern<ParamValueOp>(context), parameters(parameters) {}
@ -88,12 +117,13 @@ struct EliminateParamValueOpPattern : public OpRewritePattern<ParamValueOp> {
LogicalResult matchAndRewrite(ParamValueOp op,
PatternRewriter &rewriter) const override {
// Substitute the param value op with an evaluated constant operation.
FailureOr<APInt> paramValue =
FailureOr<Attribute> evaluated =
evaluateParametricAttr(op.getLoc(), parameters, op.value());
if (failed(paramValue))
if (failed(evaluated))
return failure();
rewriter.replaceOpWithNewOp<hw::ConstantOp>(
op, op.getType(), paramValue.getValue().getSExtValue());
op, op.getType(),
evaluated->cast<IntegerAttr>().getValue().getSExtValue());
return success();
}
@ -101,8 +131,8 @@ struct EliminateParamValueOpPattern : public OpRewritePattern<ParamValueOp> {
};
// hw.array_get operations require indexes to be of equal width of the
// array itself. Since indexes may originate from constants or parameters, emit
// comb.extract operations to fulfill this invariant.
// array itself. Since indexes may originate from constants or parameters,
// emit comb.extract operations to fulfill this invariant.
struct NarrowArrayGetIndexPattern : public OpConversionPattern<ArrayGetOp> {
public:
using OpConversionPattern<ArrayGetOp>::OpConversionPattern;
@ -185,6 +215,54 @@ static void populateTypeConversion(Location loc, TypeConverter &typeConverter,
typeConverter.addConversion([](mlir::IntegerType type) { return type; });
}
// Registers any nested parametric instance ops of `target` for the next
// specialization loop
static LogicalResult registerNestedParametricInstanceOps(
HWModuleOp target, ArrayAttr parameters, SymbolCache &sc,
const ParameterSpecializationRegistry &currentRegistry,
ParameterSpecializationRegistry &nextRegistry,
llvm::DenseMap<hw::HWModuleOp,
llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
&parametersUsers) {
// Register any nested parametric instance ops for the next loop
auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult {
auto instanceParameters = instanceOp.parameters();
// We can ignore non-parametric instances
if (instanceParameters.empty())
return WalkResult::advance();
// Replace instance parameters with evaluated versions
llvm::SmallVector<Attribute> evaluatedInstanceParameters;
evaluatedInstanceParameters.reserve(instanceParameters.size());
for (auto instanceParameter : instanceParameters) {
auto instanceParameterDecl = instanceParameter.cast<hw::ParamDeclAttr>();
auto instanceParameterValue = instanceParameterDecl.getValue();
auto evaluated = evaluateParametricAttr(target.getLoc(), parameters,
instanceParameterValue);
if (failed(evaluated))
return WalkResult::interrupt();
evaluatedInstanceParameters.push_back(hw::ParamDeclAttr::get(
instanceParameterDecl.getName(), evaluated.getValue()));
}
auto evaluatedInstanceParametersAttr =
ArrayAttr::get(target.getContext(), evaluatedInstanceParameters);
if (auto targetHWModule = targetModuleOp(instanceOp, sc)) {
if (!currentRegistry.isRegistered(targetHWModule,
evaluatedInstanceParametersAttr, sc))
nextRegistry.registerModuleOp(targetHWModule,
evaluatedInstanceParametersAttr, sc);
parametersUsers[targetHWModule][evaluatedInstanceParametersAttr]
.push_back(instanceOp);
}
return WalkResult::advance();
});
return failure(walkResult.wasInterrupted());
}
// Specializes the provided 'base' module into the 'target' module. By doing
// so, we create a new module which
// 1. has no parameters
@ -193,9 +271,13 @@ static void populateTypeConversion(Location loc, TypeConverter &typeConverter,
// 3. Has a top-level interface with any parametric types resolved.
// 4. Any references to module parameters have been replaced with the
// parameter value.
static LogicalResult specializeModule(OpBuilder builder, ArrayAttr parameters,
SymbolCache &sc, HWModuleOp source,
HWModuleOp &target) {
static LogicalResult specializeModule(
OpBuilder builder, ArrayAttr parameters, SymbolCache &sc, HWModuleOp source,
HWModuleOp &target, const ParameterSpecializationRegistry &currentRegistry,
ParameterSpecializationRegistry &nextRegistry,
llvm::DenseMap<hw::HWModuleOp,
llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
&parametersUsers) {
auto *ctx = builder.getContext();
// Update the types of the source module ports based on evaluating any
// parametric in/output ports.
@ -243,9 +325,15 @@ static LogicalResult specializeModule(OpBuilder builder, ArrayAttr parameters,
mapper.set(oldRes, newRes);
}
// Register any nested parametric instance ops for the next loop
auto nestedRegistrationResult = registerNestedParametricInstanceOps(
target, parameters, sc, currentRegistry, nextRegistry, parametersUsers);
if (failed(nestedRegistrationResult))
return failure();
// We've now created a separate copy of the source module with a rewritten
// top-level interface. Next, we enter the module to convert parametric types
// within operations.
// top-level interface. Next, we enter the module to convert parametric
// types within operations.
RewritePatternSet patterns(ctx);
TypeConverter t;
populateTypeConversion(target.getLoc(), t, parameters);
@ -267,50 +355,73 @@ void HWSpecializePass::runOnOperation() {
ModuleOp module = getOperation();
// Record unique module parameters and references to these.
llvm::DenseMap<hw::HWModuleOp, llvm::SetVector<ArrayAttr>>
uniqueModuleParameters;
llvm::DenseMap<hw::HWModuleOp,
llvm::DenseMap<ArrayAttr, llvm::SmallVector<hw::InstanceOp>>>
parametersUsers;
ParameterSpecializationRegistry registry;
// Maintain a symbol cache for fast lookup during module specialization.
SymbolCache sc;
sc.addDefinitions(module);
for (auto hwModule : module.getOps<hw::HWModuleOp>()) {
// If this module is parametric, defer registering its parametric
// instantiations until this module is specialized
if (!hwModule.parameters().empty())
continue;
for (auto instanceOp : hwModule.getOps<hw::InstanceOp>()) {
auto *targetOp = sc.getDefinition(instanceOp.moduleNameAttr());
auto targetHWModule = dyn_cast<hw::HWModuleOp>(targetOp);
if (!targetHWModule)
continue; // Won't specialize external modules.
if (auto targetHWModule = targetModuleOp(instanceOp, sc)) {
auto parameters = instanceOp.parameters();
registry.registerModuleOp(targetHWModule, parameters, sc);
if (targetHWModule.parameters().size() == 0)
continue; // nothing to record or specializeauto paramValue =.
auto parameters = instanceOp.parameters();
uniqueModuleParameters[targetHWModule].insert(parameters);
parametersUsers[targetHWModule][parameters].push_back(instanceOp);
parametersUsers[targetHWModule][parameters].push_back(instanceOp);
}
}
}
// Create specialized modules.
OpBuilder builder = OpBuilder(&getContext());
builder.setInsertionPointToStart(module.getBody());
for (auto it : uniqueModuleParameters) {
for (auto parameters : it.getSecond()) {
HWModuleOp specializedModule;
if (failed(specializeModule(builder, parameters, sc, it.getFirst(),
specializedModule))) {
signalPassFailure();
return;
llvm::DenseMap<hw::HWModuleOp, llvm::DenseMap<ArrayAttr, hw::HWModuleOp>>
specializations;
// For every module specialization, any nested parametric modules will be
// registered for the next loop. We loop until no new nested modules have been
// registered.
while (!registry.uniqueModuleParameters.empty()) {
// The registry for the next specialization loop
ParameterSpecializationRegistry nextRegistry;
for (auto it : registry.uniqueModuleParameters) {
for (auto parameters : it.second) {
HWModuleOp specializedModule;
if (failed(specializeModule(builder, parameters, sc, it.first,
specializedModule, registry, nextRegistry,
parametersUsers))) {
signalPassFailure();
return;
}
// Extend the symbol cache with the newly created module.
sc.addDefinition(specializedModule.getNameAttr(), specializedModule);
// Add the specialization
specializations[it.first][parameters] = specializedModule;
}
}
// Extend the symbol cache with the newly created module.
sc.addDefinition(specializedModule.getNameAttr(), specializedModule);
// Transfer newly registered specializations to iterate over
registry.uniqueModuleParameters =
std::move(nextRegistry.uniqueModuleParameters);
}
// Rewrite instances of the specialized module to the specialized
// module.
for (auto instanceOp : parametersUsers[it.getFirst()][parameters]) {
// Rewrite instances of specialized modules to the specialized module.
for (auto it : specializations) {
auto unspecialized = it.getFirst();
auto &users = parametersUsers[unspecialized];
for (auto specialization : it.getSecond()) {
auto parameters = specialization.getFirst();
auto specializedModule = specialization.getSecond();
for (auto instanceOp : users[parameters]) {
instanceOp->setAttr("moduleName",
FlatSymbolRefAttr::get(specializedModule));
instanceOp->setAttr("parameters", ArrayAttr::get(&getContext(), {}));

View File

@ -1,7 +1,7 @@
// RUN: circt-opt %s -export-verilog -verify-diagnostics | FileCheck %s --strict-whitespace
// CHECK-LABEL: module inputs_only(
// CHECK-NEXT: input a,
// CHECK-NEXT: input a,
// CHECK-NEXT: b);
hw.module @inputs_only(%a: i1, %b: i1) {
hw.output
@ -372,6 +372,28 @@ hw.module @UseInstances(%a_in: i8) -> (a_out1: i1, a_out2: i1) {
hw.output %xyz.out, %xyz2.out : i1, i1
}
// Instantiate a parametric module using parameters from its parent module
hw.module.extern @ExternParametricWidth<width: i32>
(%in: !hw.int<#hw.param.decl.ref<"width">>) -> (out: !hw.int<#hw.param.decl.ref<"width">>)
// CHECK-LABEL: module NestedParameterUsage
hw.module @NestedParameterUsage<param: i32>(
%in: !hw.int<#hw.param.decl.ref<"param">>) -> (out: !hw.int<#hw.param.decl.ref<"param">>) {
// CHECK: #(parameter /*integer*/ param) (
// CHECK: input [param - 1:0] in,
// CHECK: output [param - 1:0] out);
// CHECK: ExternParametricWidth #(
// CHECK: .width(param)
// CHECK: ) externWidth (
// CHECK: .in (in),
// CHECK: .out (out)
// CHECK: );
// CHECK: endmodule
%externWidth.out = hw.instance "externWidth"
@ExternParametricWidth<width: i32 = #hw.param.decl.ref<"param">>(
in: %in : !hw.int<#hw.param.decl.ref<"param">>) -> (out: !hw.int<#hw.param.decl.ref<"param">>)
hw.output %externWidth.out : !hw.int<#hw.param.decl.ref<"param">>
}
// CHECK-LABEL: module Stop(
hw.module @Stop(%clock: i1, %reset: i1) {
// CHECK: always @(posedge clock) begin

View File

@ -126,3 +126,45 @@ module {
hw.output %0 : i64
}
}
// -----
// Test a parent module instantiating parametric modules using its parameters
module {
hw.module @constantGen<V: i32>() -> (out: i32) {
%0 = hw.param.value i32 = #hw.param.decl.ref<"V">
hw.output %0 : i32
}
hw.module @takesParametericWidth<width: i32>(
%in: !hw.int<#hw.param.decl.ref<"width">>) {}
hw.module @usesConstantGen<V: i32>(
%in: !hw.int<#hw.param.decl.ref<"V">>) -> (out: i32) {
%0 = hw.instance "inst1" @constantGen<V: i32 = #hw.param.decl.ref<"V">> () -> (out: i32)
hw.instance "inst2" @takesParametericWidth<width: i32 = #hw.param.decl.ref<"V">>(in: %in : !hw.int<#hw.param.decl.ref<"V">>) -> ()
hw.output %0 : i32
}
// CHECK-LABEL: hw.module @usesConstantGen_V_8(%in: i8) -> (out: i32) {
// CHECK: %[[VAL_0:.*]] = hw.instance "inst1" @constantGen_V_8() -> (out: i32)
// CHECK: hw.instance "inst2" @takesParametericWidth_width_8(in: %in: i8) -> ()
// CHECK: hw.output %[[VAL_0]] : i32
// CHECK: }
// CHECK-LABEL: hw.module @constantGen_V_8() -> (out: i32) {
// CHECK: %[[VAL_0:.*]] = hw.constant 8 : i32
// CHECK: hw.output %[[VAL_0]] : i32
// CHECK: }
// CHECK-LABEL: hw.module @takesParametericWidth_width_8(%in: i8) {
// CHECK: hw.output
// CHECK: }
hw.module @top() -> (out: i32) {
%in = hw.constant 1 : i8
%0 = hw.instance "inst" @usesConstantGen<V: i32 = 8> (in: %in: i8) -> (out: i32)
hw.output %0 : i32
}
}