[FIRRTL] Add location to getResultType for better error reporting (#203)

* change getTypeSignature

* Apply for bitops

* Add test for bits

* Modify a bit

* clang-format

* fit into 80
This commit is contained in:
Hideto Ueno 2020-11-06 12:11:28 +09:00 committed by GitHub
parent 08f04b74f6
commit 4d31c29d2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 107 additions and 67 deletions

View File

@ -50,7 +50,8 @@ def SubfieldOp : FIRRTLOp<"subfield", [NoSideEffect]> {
/// Compute the result of a Subfield operation on a value of the specified
/// type and extracting the specified field name. If the request is
/// invalid, then a null type is returned.
static FIRRTLType getResultType(FIRRTLType inType, StringRef fieldName);
static FIRRTLType getResultType(FIRRTLType inType, StringRef fieldName,
Location loc);
}];
}
@ -77,7 +78,8 @@ def SubindexOp : FIRRTLOp<"subindex", [NoSideEffect]> {
let extraClassDeclaration = [{
/// Compute the result of a Subindex operation on a value of the specified
/// type. If the request is invalid, then a null type is returned.
static FIRRTLType getResultType(FIRRTLType inType, unsigned fieldIdx);
static FIRRTLType getResultType(FIRRTLType inType, unsigned fieldIdx,
Location loc);
}];
}
@ -103,7 +105,8 @@ def SubaccessOp : FIRRTLOp<"subaccess", [NoSideEffect]> {
let extraClassDeclaration = [{
/// Compute the result of a Subaccess operation on a value of the specified
/// type. If the request is invalid, then a null type is returned.
static FIRRTLType getResultType(FIRRTLType baseType, FIRRTLType indexType);
static FIRRTLType getResultType(FIRRTLType baseType, FIRRTLType indexType,
Location loc);
}];
}
//===----------------------------------------------------------------------===//
@ -133,14 +136,15 @@ class BinaryPrimOp<string mnemonic, string resultTypeFunction,
let extraClassDeclaration = !cast<code>(!strconcat(!cast<string>([{
/// Return the result for inputs with the specified type, returning a null
/// type if the input types are invalid.
static FIRRTLType getResultType(FIRRTLType lhs, FIRRTLType rhs) {
static FIRRTLType getResultType(FIRRTLType lhs, FIRRTLType rhs,
Location loc) {
return }]), resultTypeFunction, !cast<string>([{(lhs, rhs);
}
static FIRRTLType getResultType(ArrayRef<FIRRTLType> inputs,
ArrayRef<int32_t> integers) {
ArrayRef<int32_t> integers, Location loc) {
if (inputs.size() != 2 || !integers.empty())
return {};
return getResultType(inputs[0], inputs[1]);
return getResultType(inputs[0], inputs[1], loc);
}
}])));
}
@ -197,14 +201,14 @@ class UnaryPrimOp<string mnemonic, string resultTypeFunction,
let extraClassDeclaration = !cast<code>(!strconcat(!cast<string>([{
/// Return the result for inputs with the specified type, returning a null
/// type if the input types are invalid.
static FIRRTLType getResultType(FIRRTLType input) {
static FIRRTLType getResultType(FIRRTLType input, Location loc) {
return }]), resultTypeFunction, !cast<string>([{(input);
}
static FIRRTLType getResultType(ArrayRef<FIRRTLType> inputs,
ArrayRef<int32_t> integers) {
ArrayRef<int32_t> integers, Location loc) {
if (inputs.size() != 1 || !integers.empty())
return {};
return getResultType(inputs[0]);
return getResultType(inputs[0], loc);
}
}])));
}
@ -252,12 +256,12 @@ def BitsPrimOp : PrimOp<"bits"> {
/// Return the result for inputs with the specified type, returning a null
/// type if the input types are invalid.
static FIRRTLType getResultType(FIRRTLType input, int32_t high,
int32_t low);
int32_t low, Location loc);
static FIRRTLType getResultType(ArrayRef<FIRRTLType> inputs,
ArrayRef<int32_t> integers) {
ArrayRef<int32_t> integers, Location loc) {
if (inputs.size() != 1 || integers.size() != 2)
return {};
return getResultType(inputs[0], integers[0], integers[1]);
return getResultType(inputs[0], integers[0], integers[1], loc);
}
}];
@ -277,12 +281,13 @@ def HeadPrimOp : PrimOp<"head"> {
let extraClassDeclaration = [{
/// Return the result for inputs with the specified type, returning a null
/// type if the input types are invalid.
static FIRRTLType getResultType(FIRRTLType input, int32_t amount);
static FIRRTLType getResultType(FIRRTLType input, int32_t amount,
Location loc);
static FIRRTLType getResultType(ArrayRef<FIRRTLType> inputs,
ArrayRef<int32_t> integers) {
ArrayRef<int32_t> integers, Location loc) {
if (inputs.size() != 1 || integers.size() != 1)
return {};
return getResultType(inputs[0], integers[0]);
return getResultType(inputs[0], integers[0], loc);
}
}];
}
@ -301,12 +306,12 @@ def MuxPrimOp : PrimOp<"mux"> {
/// Return the result for inputs with the specified type, returning a null
/// type if the input types are invalid.
static FIRRTLType getResultType(FIRRTLType sel, FIRRTLType high,
FIRRTLType low);
FIRRTLType low, Location loc);
static FIRRTLType getResultType(ArrayRef<FIRRTLType> inputs,
ArrayRef<int32_t> integers) {
ArrayRef<int32_t> integers, Location loc) {
if (inputs.size() != 3 || integers.size() != 0)
return {};
return getResultType(inputs[0], inputs[1], inputs[2]);
return getResultType(inputs[0], inputs[1], inputs[2], loc);
}
}];
}
@ -330,12 +335,13 @@ def PadPrimOp : PrimOp<"pad"> {
let extraClassDeclaration = [{
/// Return the result for inputs with the specified type, returning a null
/// type if the input types are invalid.
static FIRRTLType getResultType(FIRRTLType input, int32_t amount);
static FIRRTLType getResultType(FIRRTLType input, int32_t amount,
Location loc);
static FIRRTLType getResultType(ArrayRef<FIRRTLType> inputs,
ArrayRef<int32_t> integers) {
ArrayRef<int32_t> integers, Location loc) {
if (inputs.size() != 1 || integers.size() != 1)
return {};
return getResultType(inputs[0], integers[0]);
return getResultType(inputs[0], integers[0], loc);
}
}];
}
@ -351,12 +357,13 @@ class ShiftPrimOp<string mnemonic> : PrimOp<mnemonic> {
let extraClassDeclaration = [{
/// Return the result for inputs with the specified type, returning a null
/// type if the input types are invalid.
static FIRRTLType getResultType(FIRRTLType input, int32_t amount);
static FIRRTLType getResultType(FIRRTLType input, int32_t amount,
Location loc);
static FIRRTLType getResultType(ArrayRef<FIRRTLType> inputs,
ArrayRef<int32_t> integers) {
ArrayRef<int32_t> integers, Location loc) {
if (inputs.size() != 1 || integers.size() != 1)
return {};
return getResultType(inputs[0], integers[0]);
return getResultType(inputs[0], integers[0], loc);
}
}];
}
@ -385,12 +392,13 @@ def TailPrimOp : PrimOp<"tail"> {
let extraClassDeclaration = [{
/// Return the result for inputs with the specified type, returning a null
/// type if the input types are invalid.
static FIRRTLType getResultType(FIRRTLType input, int32_t amount);
static FIRRTLType getResultType(FIRRTLType input, int32_t amount,
Location loc);
static FIRRTLType getResultType(ArrayRef<FIRRTLType> inputs,
ArrayRef<int32_t> integers) {
ArrayRef<int32_t> integers, Location loc) {
if (inputs.size() != 1 || integers.size() != 1)
return {};
return getResultType(inputs[0], integers[0]);
return getResultType(inputs[0], integers[0], loc);
}
}];
}

View File

@ -4,6 +4,7 @@
#include "circt/Dialect/FIRRTL/Ops.h"
#include "circt/Dialect/FIRRTL/Visitors.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/StandardTypes.h"
@ -795,7 +796,8 @@ void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type,
}
// Return the result of a subfield operation.
FIRRTLType SubfieldOp::getResultType(FIRRTLType inType, StringRef fieldName) {
FIRRTLType SubfieldOp::getResultType(FIRRTLType inType, StringRef fieldName,
Location loc) {
if (auto bundleType = inType.dyn_cast<BundleType>()) {
for (auto &elt : bundleType.getElements()) {
if (elt.first == fieldName)
@ -804,31 +806,33 @@ FIRRTLType SubfieldOp::getResultType(FIRRTLType inType, StringRef fieldName) {
}
if (auto flipType = inType.dyn_cast<FlipType>())
if (auto subType = getResultType(flipType.getElementType(), fieldName))
if (auto subType = getResultType(flipType.getElementType(), fieldName, loc))
return FlipType::get(subType);
return {};
}
FIRRTLType SubindexOp::getResultType(FIRRTLType inType, unsigned fieldIdx) {
FIRRTLType SubindexOp::getResultType(FIRRTLType inType, unsigned fieldIdx,
Location loc) {
if (auto vectorType = inType.dyn_cast<FVectorType>())
if (fieldIdx < vectorType.getNumElements())
return vectorType.getElementType();
if (auto flipType = inType.dyn_cast<FlipType>())
if (auto subType = getResultType(flipType.getElementType(), fieldIdx))
if (auto subType = getResultType(flipType.getElementType(), fieldIdx, loc))
return FlipType::get(subType);
return {};
}
FIRRTLType SubaccessOp::getResultType(FIRRTLType inType, FIRRTLType indexType) {
FIRRTLType SubaccessOp::getResultType(FIRRTLType inType, FIRRTLType indexType,
Location loc) {
if (auto vectorType = inType.dyn_cast<FVectorType>())
if (indexType.isa<UIntType>())
return vectorType.getElementType();
if (auto flipType = inType.dyn_cast<FlipType>())
if (auto subType = getResultType(flipType.getElementType(), indexType))
if (auto subType = getResultType(flipType.getElementType(), indexType, loc))
return FlipType::get(subType);
return {};
@ -1076,31 +1080,20 @@ FIRRTLType firrtl::getReductionResult(FIRRTLType input) {
static LogicalResult verifyBitsPrimOp(BitsPrimOp bits) {
uint32_t hi = bits.hi(), lo = bits.lo();
// High must be >= low.
if (hi < lo) {
bits.emitError()
<< "high must be equal or greater than low, but got high = " << hi
<< ", low = " << lo;
auto expectedType =
BitsPrimOp::getResultType(bits.input().getType().cast<FIRRTLType>(),
(int32_t)hi, (int32_t)lo, bits.getLoc());
if (!expectedType)
return failure();
}
// Input width must be > high.
int32_t width =
bits.input().getType().cast<IntType>().getBitWidthOrSentinel();
if (width != -1 && int32_t(hi) >= width) {
bits.emitError()
<< "high must be smaller than the width of input, but got high = " << hi
<< ", width = " << width;
return failure();
}
// Result type should be int type with (high - low + 1) width.
int32_t resultWidth =
bits.result().getType().cast<IntType>().getBitWidthOrSentinel();
if (resultWidth != -1 && int32_t(hi - lo + 1) != resultWidth) {
int32_t expectedWidth = expectedType.cast<IntType>().getBitWidthOrSentinel();
if (resultWidth != -1 && expectedWidth != resultWidth) {
bits.emitError() << "width of the result type must be equal to (high - low "
"+ 1), expected "
<< hi - lo + 1 << " but got " << resultWidth;
<< expectedWidth << " but got " << resultWidth;
return failure();
}
@ -1108,30 +1101,51 @@ static LogicalResult verifyBitsPrimOp(BitsPrimOp bits) {
}
FIRRTLType BitsPrimOp::getResultType(FIRRTLType input, int32_t high,
int32_t low) {
int32_t low, Location loc) {
auto inputi = input.dyn_cast<IntType>();
// High must be >= low and both most be non-negative.
if (!inputi || high < low || low < 0)
if (!inputi) {
mlir::emitError(loc) << "input type should be the int type but got "
<< input;
return {};
}
// High must be >= low and both most be non-negative.
if (high < low) {
mlir::emitError(loc)
<< "high must be equal or greater than low, but got high = " << high
<< ", low = " << low;
return {};
}
if (low < 0) {
mlir::emitError(loc) << "low must be non-negative but got" << low;
return {};
}
// If the input has staticly known width, check it. Both and low must be
// strictly less than width.
int32_t width = inputi.getWidthOrSentinel();
if (width != -1 && high >= width)
if (width != -1 && high >= width) {
mlir::emitError(loc)
<< "high must be smaller than the width of input, but got high = "
<< high << ", width = " << width;
return {};
}
return UIntType::get(input.getContext(), high - low + 1);
}
void BitsPrimOp::build(OpBuilder &builder, OperationState &result, Value input,
unsigned high, unsigned low) {
auto type = getResultType(input.getType().cast<FIRRTLType>(), high, low);
auto type = getResultType(input.getType().cast<FIRRTLType>(), high, low,
result.location);
assert(type && "invalid inputs building BitsPrimOp!");
build(builder, result, type, input, high, low);
}
FIRRTLType HeadPrimOp::getResultType(FIRRTLType input, int32_t amount) {
FIRRTLType HeadPrimOp::getResultType(FIRRTLType input, int32_t amount,
Location loc) {
auto inputi = input.dyn_cast<IntType>();
if (amount <= 0 || !inputi)
return {};
@ -1145,7 +1159,7 @@ FIRRTLType HeadPrimOp::getResultType(FIRRTLType input, int32_t amount) {
}
FIRRTLType MuxPrimOp::getResultType(FIRRTLType sel, FIRRTLType high,
FIRRTLType low) {
FIRRTLType low, Location loc) {
// Sel needs to be a one bit uint or an unknown width uint.
auto selui = sel.dyn_cast<UIntType>();
if (!selui || selui.getWidthOrSentinel() > 1)
@ -1191,7 +1205,8 @@ FIRRTLType MuxPrimOp::getResultType(FIRRTLType sel, FIRRTLType high,
return {};
}
FIRRTLType PadPrimOp::getResultType(FIRRTLType input, int32_t amount) {
FIRRTLType PadPrimOp::getResultType(FIRRTLType input, int32_t amount,
Location loc) {
auto inputi = input.dyn_cast<IntType>();
if (amount < 0 || !inputi)
return {};
@ -1204,7 +1219,8 @@ FIRRTLType PadPrimOp::getResultType(FIRRTLType input, int32_t amount) {
return IntType::get(input.getContext(), inputi.isSigned(), width);
}
FIRRTLType ShlPrimOp::getResultType(FIRRTLType input, int32_t amount) {
FIRRTLType ShlPrimOp::getResultType(FIRRTLType input, int32_t amount,
Location loc) {
auto inputi = input.dyn_cast<IntType>();
if (amount < 0 || !inputi)
return {};
@ -1216,7 +1232,8 @@ FIRRTLType ShlPrimOp::getResultType(FIRRTLType input, int32_t amount) {
return IntType::get(input.getContext(), inputi.isSigned(), width);
}
FIRRTLType ShrPrimOp::getResultType(FIRRTLType input, int32_t amount) {
FIRRTLType ShrPrimOp::getResultType(FIRRTLType input, int32_t amount,
Location loc) {
auto inputi = input.dyn_cast<IntType>();
if (amount < 0 || !inputi)
return {};
@ -1228,7 +1245,8 @@ FIRRTLType ShrPrimOp::getResultType(FIRRTLType input, int32_t amount) {
return IntType::get(input.getContext(), inputi.isSigned(), width);
}
FIRRTLType TailPrimOp::getResultType(FIRRTLType input, int32_t amount) {
FIRRTLType TailPrimOp::getResultType(FIRRTLType input, int32_t amount,
Location loc) {
auto inputi = input.dyn_cast<IntType>();
if (amount < 0 || !inputi)
return {};

View File

@ -902,7 +902,8 @@ ParseResult FIRStmtParser::parsePostFixFieldId(Value &result,
// Make sure the field name matches up with the input value's type and
// compute the result type for the expression.
auto resultType = result.getType().cast<FIRRTLType>();
resultType = SubfieldOp::getResultType(resultType, fieldName);
resultType =
SubfieldOp::getResultType(resultType, fieldName, translateLocation(loc));
if (!resultType) {
// TODO(QoI): This error would be nicer with a .fir pretty print of the
// type.
@ -938,7 +939,8 @@ ParseResult FIRStmtParser::parsePostFixIntSubscript(Value &result,
// Make sure the index expression is valid and compute the result type for the
// expression.
auto resultType = result.getType().cast<FIRRTLType>();
resultType = SubindexOp::getResultType(resultType, indexNo);
resultType = SubindexOp::getResultType(resultType, indexNo,
translateLocation(indexLoc));
if (!resultType) {
// TODO(QoI): This error would be nicer with a .fir pretty print of the
// type.
@ -979,7 +981,8 @@ ParseResult FIRStmtParser::parsePostFixDynamicSubscript(Value &result,
// Make sure the index expression is valid and compute the result type for the
// expression.
auto resultType = result.getType().cast<FIRRTLType>();
resultType = SubaccessOp::getResultType(resultType, indexType);
resultType = SubaccessOp::getResultType(resultType, indexType,
translateLocation(indexLoc));
if (!resultType) {
// TODO(QoI): This error would be nicer with a .fir pretty print of the
// type.
@ -1086,7 +1089,8 @@ ParseResult FIRStmtParser::parsePrimExp(Value &result, SubOpVector &subOps) {
#define TOK_LPKEYWORD_PRIM(SPELLING, CLASS) \
case FIRToken::lp_##SPELLING: { \
auto resultTy = CLASS::getResultType(opTypes, integers); \
auto resultTy = \
CLASS::getResultType(opTypes, integers, translateLocation(loc)); \
if (!resultTy) \
return typeError(#SPELLING); \
result = builder.create<CLASS>(translateLocation(loc), resultTy, \

View File

@ -112,3 +112,13 @@ circuit test :
module invalid_name :
input bf: { flip int_1 : UInt<1>, int_out : UInt<2>}
node n4 = add(bf, bf) ; expected-error {{invalid input types for 'add'}}
;// -----
circuit test :
module invalid_bits:
input a: UInt<8>
output b: UInt<4>
; expected-error@+2 {{high must be equal or greater than low, but got high = 4, low = 7}};
; expected-error@+1 {{invalid input types for 'bits': '!firrtl.uint<8>'}}
b <= bits(a, 4, 7)