[HW] Add IR parser/printer/verifier support for new-style instance parameters.

This commit is contained in:
Chris Lattner 2021-09-18 15:57:47 -07:00
parent 985a547ffc
commit 18746a3b4f
6 changed files with 226 additions and 90 deletions

View File

@ -20,9 +20,12 @@ def OutputFileAttr : StructAttr<"OutputFileAttr", HWDialect, [
DefaultValuedAttr<BoolAttr, "true">>,
]>;
/// An attribute describing a module parameter.
/// An attribute describing a module parameter, or instance parameter
/// specification.
def ParameterAttr : StructAttr<"ParameterAttr", HWDialect, [
/// This is the name of the parameter.
StructFieldAttr<"name", StrAttr>,
/// This is the MLIR type for it.
StructFieldAttr<"type", TypeAttr>,
/// This is the value of the attribute - in a module, this is the default
/// value (and may be missing). In an instance, this is a required field that
@ -32,6 +35,8 @@ def ParameterAttr : StructAttr<"ParameterAttr", HWDialect, [
StructFieldAttr<"value", OptionalAttr<AnyAttr>>
]>;
/// An array of ParameterAttr's.
/// An array of ParameterAttr's that may or may not have a 'value' specified,
/// to be used on hw.module or hw.instance. The hw.instance verifier further
/// ensures that all the values are specified.
def ParameterArrayAttr
: TypedArrayAttrBase<ParameterAttr, "parameter array attribute">;

View File

@ -326,6 +326,7 @@ def InstanceOp : HWOp<"instance", [HasParent<"HWModuleOp">, Symbol,
Confined<FlatSymbolRefAttr, [isModuleSymbol]>:$moduleName,
Variadic<AnyType>:$inputs,
StrArrayAttr:$argNames, StrArrayAttr:$resultNames,
ParameterArrayAttr:$parameters,
OptionalAttr<DictionaryAttr>:$oldParameters,
OptionalAttr<SymbolNameAttr>:$sym_name);
let results = (outs Variadic<AnyType>);
@ -334,12 +335,27 @@ def InstanceOp : HWOp<"instance", [HasParent<"HWModuleOp">, Symbol,
/// Create a instance that refers to a known module.
OpBuilder<(ins "Operation*":$module, "StringAttr":$name,
"ArrayRef<Value>":$inputs,
CArg<"DictionaryAttr", "{}">:$oldParameters,
CArg<"ArrayAttr", "{}">:$parameters,
CArg<"StringAttr", "{}">:$sym_name)>,
/// Create a instance that refers to a known module.
OpBuilder<(ins "Operation*":$module, "StringRef":$name,
"ArrayRef<Value>":$inputs,
CArg<"DictionaryAttr", "{}">:$oldParameters,
CArg<"ArrayAttr", "{}">:$parameters,
CArg<"StringAttr", "{}">:$sym_name), [{
build($_builder, $_state, module, $_builder.getStringAttr(name), inputs,
parameters, sym_name);
}]>,
/// TODO: Remove these builders that support the oldParameter format.
/// Create a instance that refers to a known module.
OpBuilder<(ins "Operation*":$module, "StringAttr":$name,
"ArrayRef<Value>":$inputs,
"DictionaryAttr":$oldParameters,
CArg<"StringAttr", "{}">:$sym_name)>,
/// Create a instance that refers to a known module.
OpBuilder<(ins "Operation*":$module, "StringRef":$name,
"ArrayRef<Value>":$inputs,
"DictionaryAttr":$oldParameters,
CArg<"StringAttr", "{}">:$sym_name), [{
build($_builder, $_state, module, $_builder.getStringAttr(name), inputs,
oldParameters, sym_name);

View File

@ -2092,7 +2092,8 @@ LogicalResult FIRRTLLowering::visitDecl(MemOp op) {
auto inst = builder.create<hw::InstanceOp>(
resultTypes, builder.getStringAttr(memName), memModuleAttr, operands,
builder.getArrayAttr(argNames), builder.getArrayAttr(resultNames),
DictionaryAttr(), StringAttr());
/*parameters=*/builder.getArrayAttr({}), /*oldParams=*/DictionaryAttr(),
/*sym_name=*/StringAttr());
// Update all users of the result of read ports
for (auto &ret : returnHolder)
(void)setLowering(ret.first->getResult(0), inst.getResult(ret.second));

View File

@ -556,6 +556,23 @@ FunctionType getHWModuleOpType(Operation *op) {
return typeAttr.getValue().cast<FunctionType>();
}
/// Print a parameter list for a module or instance.
static void printParameterList(ArrayAttr parameters, OpAsmPrinter &p) {
if (parameters.empty())
return;
p << '<';
llvm::interleaveComma(parameters, p, [&](Attribute param) {
auto paramAttr = param.cast<ParameterAttr>();
p << paramAttr.name().getValue() << ": " << paramAttr.type();
if (auto value = paramAttr.value()) {
p << " = ";
p.printAttributeWithoutType(value);
}
});
p << '>';
}
static void printModuleOp(OpAsmPrinter &p, Operation *op,
ExternModKind modKind) {
using namespace mlir::function_like_impl;
@ -573,19 +590,7 @@ static void printModuleOp(OpAsmPrinter &p, Operation *op,
}
// Print the parameter list if present.
auto parameters = op->getAttrOfType<ArrayAttr>("parameters");
if (!parameters.empty()) {
p << '<';
llvm::interleaveComma(parameters, p, [&](Attribute param) {
auto paramAttr = param.cast<ParameterAttr>();
p << paramAttr.name().getValue() << ": " << paramAttr.type();
if (auto value = paramAttr.value()) {
p << " = ";
p.printAttributeWithoutType(value);
}
});
p << '>';
}
printParameterList(op->getAttrOfType<ArrayAttr>("parameters"), p);
bool needArgNamesAttr = false;
module_like_impl::printModuleSignature(p, op, argTypes, /*isVariadic=*/false,
@ -681,6 +686,25 @@ static LogicalResult verifyHWModuleGeneratedOp(HWModuleGeneratedOp op) {
// InstanceOp
//===----------------------------------------------------------------------===//
/// Create a instance that refers to a known module.
void InstanceOp::build(OpBuilder &builder, OperationState &result,
Operation *module, StringAttr name,
ArrayRef<Value> inputs, ArrayAttr parameters,
StringAttr sym_name) {
assert(isAnyModule(module) && "Can only reference a module");
if (!parameters)
parameters = builder.getArrayAttr({});
FunctionType modType = getModuleType(module);
build(builder, result, modType.getResults(), name,
FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
module->getAttrOfType<ArrayAttr>("argNames"),
module->getAttrOfType<ArrayAttr>("resultNames"), parameters,
/*oldParameters*/ DictionaryAttr(), sym_name);
}
/// TODO: Remove these builders that support the oldParameter format.
/// Create a instance that refers to a known module.
void InstanceOp::build(OpBuilder &builder, OperationState &result,
Operation *module, StringAttr name,
@ -691,8 +715,8 @@ void InstanceOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, modType.getResults(), name,
FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
module->getAttrOfType<ArrayAttr>("argNames"),
module->getAttrOfType<ArrayAttr>("resultNames"), oldParameters,
sym_name);
module->getAttrOfType<ArrayAttr>("resultNames"),
/*parameters*/ builder.getArrayAttr({}), oldParameters, sym_name);
}
/// Lookup the module or extmodule for the symbol. This returns null on
@ -708,6 +732,16 @@ Operation *InstanceOp::getReferencedModule(const SymbolCache *cache) {
// Helper function to verify instance op types
static LogicalResult verifyInstanceOpTypes(InstanceOp op, Operation *module) {
// Emit an error message on the instance, with a note indicating which module
// is being referenced.
auto emitError =
[&](std::function<void(InFlightDiagnostic & diag)> fn) -> LogicalResult {
auto diag = op.emitOpError();
fn(diag);
diag.attachNote(module->getLoc()) << "module declared here";
return failure();
};
// Make sure our port and result names match.
ArrayAttr argNames = op.argNamesAttr();
ArrayAttr modArgNames = module->getAttrOfType<ArrayAttr>("argNames");
@ -716,41 +750,32 @@ static LogicalResult verifyInstanceOpTypes(InstanceOp op, Operation *module) {
auto numOperands = op->getNumOperands();
auto expectedOperandTypes = getModuleType(module).getInputs();
if (expectedOperandTypes.size() != numOperands) {
auto diag = op.emitOpError()
<< "has a wrong number of operands; expected "
<< expectedOperandTypes.size() << " but got " << numOperands;
diag.attachNote(module->getLoc()) << "original module declared here";
return failure();
}
if (expectedOperandTypes.size() != numOperands)
return emitError([&](auto &diag) {
diag << "has a wrong number of operands; expected "
<< expectedOperandTypes.size() << " but got " << numOperands;
});
if (argNames.size() != numOperands) {
auto diag = op.emitOpError()
<< "has a wrong number of input port names; expected "
<< numOperands << " but got " << argNames.size();
diag.attachNote(module->getLoc()) << "original module declared here";
return failure();
}
if (argNames.size() != numOperands)
return emitError([&](auto &diag) {
diag << "has a wrong number of input port names; expected " << numOperands
<< " but got " << argNames.size();
});
for (size_t i = 0; i != numOperands; ++i) {
auto expectedType = expectedOperandTypes[i];
auto operandType = op.getOperand(i).getType();
if (operandType != expectedType) {
auto diag = op.emitOpError()
<< "operand type #" << i << " must be " << expectedType
<< ", but got " << operandType;
diag.attachNote(module->getLoc()) << "original module declared here";
return failure();
}
if (operandType != expectedType)
return emitError([&](auto &diag) {
diag << "operand type #" << i << " must be " << expectedType
<< ", but got " << operandType;
});
if (argNames[i] != modArgNames[i]) {
auto diag = op.emitOpError()
<< "input label #" << i << " must be " << modArgNames[i]
<< ", but got " << argNames[i];
diag.attachNote(module->getLoc()) << "original module declared here";
module->dump();
return failure();
}
if (argNames[i] != modArgNames[i])
return emitError([&](auto &diag) {
diag << "input label #" << i << " must be " << modArgNames[i]
<< ", but got " << argNames[i];
});
}
// Check result types and labels.
@ -759,39 +784,64 @@ static LogicalResult verifyInstanceOpTypes(InstanceOp op, Operation *module) {
ArrayAttr resultNames = op.resultNamesAttr();
ArrayAttr modResultNames = module->getAttrOfType<ArrayAttr>("resultNames");
if (expectedResultTypes.size() != numResults) {
auto diag = op.emitOpError()
<< "has a wrong number of results; expected "
<< expectedResultTypes.size() << " but got " << numResults;
diag.attachNote(module->getLoc()) << "original module declared here";
return failure();
}
if (resultNames.size() != numResults) {
auto diag = op.emitOpError()
<< "has a wrong number of results port labels; expected "
<< numResults << " but got " << resultNames.size();
diag.attachNote(module->getLoc()) << "original module declared here";
return failure();
}
if (expectedResultTypes.size() != numResults)
return emitError([&](auto &diag) {
diag << "has a wrong number of results; expected "
<< expectedResultTypes.size() << " but got " << numResults;
});
if (resultNames.size() != numResults)
return emitError([&](auto &diag) {
diag << "has a wrong number of results port labels; expected "
<< numResults << " but got " << resultNames.size();
});
for (size_t i = 0; i != numResults; ++i) {
auto expectedType = expectedResultTypes[i];
auto resultType = op.getResult(i).getType();
if (resultType != expectedType) {
auto diag = op.emitOpError()
<< "result type #" << i << " must be " << expectedType
<< ", but got " << resultType;
diag.attachNote(module->getLoc()) << "original module declared here";
return failure();
}
if (resultType != expectedType)
return emitError([&](auto &diag) {
diag << "result type #" << i << " must be " << expectedType
<< ", but got " << resultType;
});
if (resultNames[i] != modResultNames[i]) {
auto diag = op.emitOpError()
<< "input label #" << i << " must be " << modResultNames[i]
<< ", but got " << resultNames[i];
diag.attachNote(module->getLoc()) << "original module declared here";
return failure();
}
if (resultNames[i] != modResultNames[i])
return emitError([&](auto &diag) {
diag << "input label #" << i << " must be " << modResultNames[i]
<< ", but got " << resultNames[i];
});
}
// Check parameters match up.
ArrayAttr parameters = op.parameters();
ArrayAttr modParameters = module->getAttrOfType<ArrayAttr>("parameters");
auto numParameters = parameters.size();
if (numParameters != modParameters.size())
return emitError([&](auto &diag) {
diag << "expected " << modParameters.size() << " parameters but had "
<< numParameters;
});
for (size_t i = 0; i != numParameters; ++i) {
auto param = parameters[i].cast<ParameterAttr>();
auto modParam = modParameters[i].cast<ParameterAttr>();
if (param.name() != modParam.name())
return emitError([&](auto &diag) {
diag << "parameter #" << i << " should have name " << modParam.name()
<< " but has name " << param.name();
});
if (param.type() != modParam.type())
return emitError([&](auto &diag) {
diag << "parameter " << param.name() << " should have type "
<< modParam.type() << " but has type " << param.type();
});
// All instance parameters must have a value. Specify the same value as
// a module's default value if you want the default.
if (!param.value())
return op.emitOpError("parameter ")
<< param.name() << " must have a value";
}
return success();
@ -820,7 +870,7 @@ static ParseResult parseInstanceOp(OpAsmParser &parser,
SmallVector<OpAsmParser::OperandType, 4> inputsOperands;
SmallVector<Type> inputsTypes;
SmallVector<Type> allResultTypes;
SmallVector<Attribute> argNames, resultNames;
SmallVector<Attribute> argNames, resultNames, parameters;
auto noneType = parser.getBuilder().getType<NoneType>();
if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName",
@ -853,9 +903,11 @@ static ParseResult parseInstanceOp(OpAsmParser &parser,
return parser.parseColonType(allResultTypes.back());
};
llvm::SMLoc inputsOperandsLoc;
llvm::SMLoc parametersLoc, inputsOperandsLoc;
if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName",
result.attributes) ||
parser.getCurrentLocation(&parametersLoc) ||
parseOptionalParameters(parser, parameters) ||
parser.getCurrentLocation(&inputsOperandsLoc) ||
parseCommaSeparatedList(parser, Delimiter::Paren, parseInputPort) ||
parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
@ -869,6 +921,8 @@ static ParseResult parseInstanceOp(OpAsmParser &parser,
result.addAttribute("argNames", parser.getBuilder().getArrayAttr(argNames));
result.addAttribute("resultNames",
parser.getBuilder().getArrayAttr(resultNames));
result.addAttribute("parameters",
parser.getBuilder().getArrayAttr(parameters));
result.addTypes(allResultTypes);
return success();
}
@ -896,6 +950,7 @@ static void printInstanceOp(OpAsmPrinter &p, InstanceOp op) {
}
p << ' ';
p.printAttributeWithoutType(op.moduleNameAttr());
printParameterList(op.parameters(), p);
p << '(';
llvm::interleaveComma(op.inputs(), p, [&](Value op) {
printPortName(nextInputPort, portInfo.inputs);
@ -907,9 +962,9 @@ static void printInstanceOp(OpAsmPrinter &p, InstanceOp op) {
p << res.getType();
});
p << ')';
p.printOptionalAttrDict(
op->getAttrs(), /*elidedAttrs=*/{"instanceName", "sym_name", "moduleName",
"argNames", "resultNames"});
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{
"instanceName", "sym_name", "moduleName",
"argNames", "resultNames", "parameters"});
}
/// Return the name of the specified input port or null if it cannot be

View File

@ -132,7 +132,7 @@ hw.module @invalid_add(%a: i0) { // i0 ports are ok.
// -----
// expected-note @+1 {{original module declared here}}
// expected-note @+1 {{module declared here}}
hw.module @empty() -> () {
hw.output
}
@ -145,7 +145,7 @@ hw.module @test() -> () {
// -----
// expected-note @+1 {{original module declared here}}
// expected-note @+1 {{module declared here}}
hw.module @f() -> (a: i2) {
%a = hw.constant 1 : i2
hw.output %a : i2
@ -159,7 +159,7 @@ hw.module @test() -> () {
// -----
// expected-note @+1 {{original module declared here}}
// expected-note @+1 {{module declared here}}
hw.module @empty() -> () {
hw.output
}
@ -172,7 +172,7 @@ hw.module @test(%a: i1) -> () {
// -----
// expected-note @+1 {{original module declared here}}
// expected-note @+1 {{module declared here}}
hw.module @f(%a: i1) -> () {
hw.output
}
@ -186,7 +186,7 @@ hw.module @test(%a: i2) -> () {
// -----
// expected-note @+1 {{original module declared here}}
// expected-note @+1 {{module declared here}}
hw.module @f(%a: i1) -> () {
hw.output
}
@ -196,3 +196,47 @@ hw.module @test(%a: i1) -> () {
hw.instance "test" @f(b: %a: i1) -> ()
hw.output
}
// -----
// expected-note @+1 {{module declared here}}
hw.module.extern @p<p1: i42 = 17, p2: i1>(%arg0: i8) -> (out: i8)
hw.module @Use(%a: i8) -> (xx: i8) {
// expected-error @+1 {{op expected 2 parameters but had 1}}
%r0 = hw.instance "inst1" @p<p1: i42 = 4>(arg0: %a: i8) -> (out: i8)
hw.output %r0: i8
}
// -----
// expected-note @+1 {{module declared here}}
hw.module.extern @p<p1: i42 = 17, p2: i1>(%arg0: i8) -> (out: i8)
hw.module @Use(%a: i8) -> (xx: i8) {
// expected-error @+1 {{op parameter #1 should have name "p2" but has name "p3"}}
%r0 = hw.instance "inst1" @p<p1: i42 = 4, p3: i1 = 0>(arg0: %a: i8) -> (out: i8)
hw.output %r0: i8
}
// -----
// expected-note @+1 {{module declared here}}
hw.module.extern @p<p1: i42 = 17, p2: i1>(%arg0: i8) -> (out: i8)
hw.module @Use(%a: i8) -> (xx: i8) {
// expected-error @+1 {{op parameter "p2" should have type i1 but has type i2}}
%r0 = hw.instance "inst1" @p<p1: i42 = 4, p2: i2 = 0>(arg0: %a: i8) -> (out: i8)
hw.output %r0: i8
}
// -----
hw.module.extern @p<p1: i42 = 17, p2: i1>(%arg0: i8) -> (out: i8)
hw.module @Use(%a: i8) -> (xx: i8) {
// expected-error @+1 {{op parameter "p2" must have a value}}
%r0 = hw.instance "inst1" @p<p1: i42 = 4, p2: i1>(arg0: %a: i8) -> (out: i8)
hw.output %r0: i8
}

View File

@ -58,7 +58,22 @@ hw.module.generated @genmod1, @MEMORY() -> (FOOBAR: i1) attributes {write_latenc
// CHECK-LABEL: hw.module.extern @AnonArg(i42)
hw.module.extern @AnonArg(i42)
// CHECK-LABEL: hw.module @parameters<p1: i42 = 17, p2: i1>(%arg0: si8) -> (out: si8) {
hw.module @parameters<p1: i42 = 17, p2: i1>(%arg0: si8) -> (out: si8) {
hw.output %arg0 : si8
// CHECK-LABEL: hw.module @parameters<p1: i42 = 17, p2: i1>(%arg0: i8) -> (out: i8) {
hw.module @parameters<p1: i42 = 17, p2: i1>(%arg0: i8) -> (out: i8) {
hw.output %arg0 : i8
}
// CHECK-LABEL: hw.module @UseParameterized(
hw.module @UseParameterized(%a: i8) -> (xx: i8, yy: i8, zz: i8) {
// CHECK: %inst1.out = hw.instance "inst1" @parameters<p1: i42 = 4, p2: i1 = false>(arg0:
%r0 = hw.instance "inst1" @parameters<p1: i42 = 4, p2: i1 = 0>(arg0: %a: i8) -> (out: i8)
// CHECK: %inst2.out = hw.instance "inst2" @parameters<p1: i42 = 11, p2: i1 = true>(arg0:
%r1 = hw.instance "inst2" @parameters<p1: i42 = 11, p2: i1 = 1>(arg0: %a: i8) -> (out: i8)
// CHECK: %inst3.out = hw.instance "inst3" @parameters<p1: i42 = 17, p2: i1 = false>(arg0:
%r2 = hw.instance "inst3" @parameters<p1: i42 = 17, p2: i1 = 0>(arg0: %a: i8) -> (out: i8)
hw.output %r0, %r1, %r2: i8, i8, i8
}