diff --git a/include/circt/Dialect/FIRRTL/OpExpressions.td b/include/circt/Dialect/FIRRTL/OpExpressions.td index a313dc5c37..b9ee07428b 100644 --- a/include/circt/Dialect/FIRRTL/OpExpressions.td +++ b/include/circt/Dialect/FIRRTL/OpExpressions.td @@ -50,7 +50,8 @@ def SubfieldOp : FIRRTLOp<"subfield", [NoSideEffect]> { /// Compute the result of a Subfield operation on a value of the specified /// type and extracting the specified field name. If the request is /// invalid, then a null type is returned. - static FIRRTLType getResultType(FIRRTLType inType, StringRef fieldName); + static FIRRTLType getResultType(FIRRTLType inType, StringRef fieldName, + Location loc); }]; } @@ -77,7 +78,8 @@ def SubindexOp : FIRRTLOp<"subindex", [NoSideEffect]> { let extraClassDeclaration = [{ /// Compute the result of a Subindex operation on a value of the specified /// type. If the request is invalid, then a null type is returned. - static FIRRTLType getResultType(FIRRTLType inType, unsigned fieldIdx); + static FIRRTLType getResultType(FIRRTLType inType, unsigned fieldIdx, + Location loc); }]; } @@ -103,7 +105,8 @@ def SubaccessOp : FIRRTLOp<"subaccess", [NoSideEffect]> { let extraClassDeclaration = [{ /// Compute the result of a Subaccess operation on a value of the specified /// type. If the request is invalid, then a null type is returned. - static FIRRTLType getResultType(FIRRTLType baseType, FIRRTLType indexType); + static FIRRTLType getResultType(FIRRTLType baseType, FIRRTLType indexType, + Location loc); }]; } //===----------------------------------------------------------------------===// @@ -133,14 +136,15 @@ class BinaryPrimOp(!strconcat(!cast([{ /// Return the result for inputs with the specified type, returning a null /// type if the input types are invalid. - static FIRRTLType getResultType(FIRRTLType lhs, FIRRTLType rhs) { + static FIRRTLType getResultType(FIRRTLType lhs, FIRRTLType rhs, + Location loc) { return }]), resultTypeFunction, !cast([{(lhs, rhs); } static FIRRTLType getResultType(ArrayRef inputs, - ArrayRef integers) { + ArrayRef integers, Location loc) { if (inputs.size() != 2 || !integers.empty()) return {}; - return getResultType(inputs[0], inputs[1]); + return getResultType(inputs[0], inputs[1], loc); } }]))); } @@ -197,14 +201,14 @@ class UnaryPrimOp(!strconcat(!cast([{ /// Return the result for inputs with the specified type, returning a null /// type if the input types are invalid. - static FIRRTLType getResultType(FIRRTLType input) { + static FIRRTLType getResultType(FIRRTLType input, Location loc) { return }]), resultTypeFunction, !cast([{(input); } static FIRRTLType getResultType(ArrayRef inputs, - ArrayRef integers) { + ArrayRef integers, Location loc) { if (inputs.size() != 1 || !integers.empty()) return {}; - return getResultType(inputs[0]); + return getResultType(inputs[0], loc); } }]))); } @@ -252,12 +256,12 @@ def BitsPrimOp : PrimOp<"bits"> { /// Return the result for inputs with the specified type, returning a null /// type if the input types are invalid. static FIRRTLType getResultType(FIRRTLType input, int32_t high, - int32_t low); + int32_t low, Location loc); static FIRRTLType getResultType(ArrayRef inputs, - ArrayRef integers) { + ArrayRef integers, Location loc) { if (inputs.size() != 1 || integers.size() != 2) return {}; - return getResultType(inputs[0], integers[0], integers[1]); + return getResultType(inputs[0], integers[0], integers[1], loc); } }]; @@ -277,12 +281,13 @@ def HeadPrimOp : PrimOp<"head"> { let extraClassDeclaration = [{ /// Return the result for inputs with the specified type, returning a null /// type if the input types are invalid. - static FIRRTLType getResultType(FIRRTLType input, int32_t amount); + static FIRRTLType getResultType(FIRRTLType input, int32_t amount, + Location loc); static FIRRTLType getResultType(ArrayRef inputs, - ArrayRef integers) { + ArrayRef integers, Location loc) { if (inputs.size() != 1 || integers.size() != 1) return {}; - return getResultType(inputs[0], integers[0]); + return getResultType(inputs[0], integers[0], loc); } }]; } @@ -301,12 +306,12 @@ def MuxPrimOp : PrimOp<"mux"> { /// Return the result for inputs with the specified type, returning a null /// type if the input types are invalid. static FIRRTLType getResultType(FIRRTLType sel, FIRRTLType high, - FIRRTLType low); + FIRRTLType low, Location loc); static FIRRTLType getResultType(ArrayRef inputs, - ArrayRef integers) { + ArrayRef integers, Location loc) { if (inputs.size() != 3 || integers.size() != 0) return {}; - return getResultType(inputs[0], inputs[1], inputs[2]); + return getResultType(inputs[0], inputs[1], inputs[2], loc); } }]; } @@ -330,12 +335,13 @@ def PadPrimOp : PrimOp<"pad"> { let extraClassDeclaration = [{ /// Return the result for inputs with the specified type, returning a null /// type if the input types are invalid. - static FIRRTLType getResultType(FIRRTLType input, int32_t amount); + static FIRRTLType getResultType(FIRRTLType input, int32_t amount, + Location loc); static FIRRTLType getResultType(ArrayRef inputs, - ArrayRef integers) { + ArrayRef integers, Location loc) { if (inputs.size() != 1 || integers.size() != 1) return {}; - return getResultType(inputs[0], integers[0]); + return getResultType(inputs[0], integers[0], loc); } }]; } @@ -351,12 +357,13 @@ class ShiftPrimOp : PrimOp { let extraClassDeclaration = [{ /// Return the result for inputs with the specified type, returning a null /// type if the input types are invalid. - static FIRRTLType getResultType(FIRRTLType input, int32_t amount); + static FIRRTLType getResultType(FIRRTLType input, int32_t amount, + Location loc); static FIRRTLType getResultType(ArrayRef inputs, - ArrayRef integers) { + ArrayRef integers, Location loc) { if (inputs.size() != 1 || integers.size() != 1) return {}; - return getResultType(inputs[0], integers[0]); + return getResultType(inputs[0], integers[0], loc); } }]; } @@ -385,12 +392,13 @@ def TailPrimOp : PrimOp<"tail"> { let extraClassDeclaration = [{ /// Return the result for inputs with the specified type, returning a null /// type if the input types are invalid. - static FIRRTLType getResultType(FIRRTLType input, int32_t amount); + static FIRRTLType getResultType(FIRRTLType input, int32_t amount, + Location loc); static FIRRTLType getResultType(ArrayRef inputs, - ArrayRef integers) { + ArrayRef integers, Location loc) { if (inputs.size() != 1 || integers.size() != 1) return {}; - return getResultType(inputs[0], integers[0]); + return getResultType(inputs[0], integers[0], loc); } }]; } diff --git a/lib/Dialect/FIRRTL/Ops.cpp b/lib/Dialect/FIRRTL/Ops.cpp index e17315bc5d..45a5592718 100644 --- a/lib/Dialect/FIRRTL/Ops.cpp +++ b/lib/Dialect/FIRRTL/Ops.cpp @@ -4,6 +4,7 @@ #include "circt/Dialect/FIRRTL/Ops.h" #include "circt/Dialect/FIRRTL/Visitors.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/StandardTypes.h" @@ -795,7 +796,8 @@ void ConstantOp::build(OpBuilder &builder, OperationState &result, IntType type, } // Return the result of a subfield operation. -FIRRTLType SubfieldOp::getResultType(FIRRTLType inType, StringRef fieldName) { +FIRRTLType SubfieldOp::getResultType(FIRRTLType inType, StringRef fieldName, + Location loc) { if (auto bundleType = inType.dyn_cast()) { for (auto &elt : bundleType.getElements()) { if (elt.first == fieldName) @@ -804,31 +806,33 @@ FIRRTLType SubfieldOp::getResultType(FIRRTLType inType, StringRef fieldName) { } if (auto flipType = inType.dyn_cast()) - if (auto subType = getResultType(flipType.getElementType(), fieldName)) + if (auto subType = getResultType(flipType.getElementType(), fieldName, loc)) return FlipType::get(subType); return {}; } -FIRRTLType SubindexOp::getResultType(FIRRTLType inType, unsigned fieldIdx) { +FIRRTLType SubindexOp::getResultType(FIRRTLType inType, unsigned fieldIdx, + Location loc) { if (auto vectorType = inType.dyn_cast()) if (fieldIdx < vectorType.getNumElements()) return vectorType.getElementType(); if (auto flipType = inType.dyn_cast()) - if (auto subType = getResultType(flipType.getElementType(), fieldIdx)) + if (auto subType = getResultType(flipType.getElementType(), fieldIdx, loc)) return FlipType::get(subType); return {}; } -FIRRTLType SubaccessOp::getResultType(FIRRTLType inType, FIRRTLType indexType) { +FIRRTLType SubaccessOp::getResultType(FIRRTLType inType, FIRRTLType indexType, + Location loc) { if (auto vectorType = inType.dyn_cast()) if (indexType.isa()) return vectorType.getElementType(); if (auto flipType = inType.dyn_cast()) - if (auto subType = getResultType(flipType.getElementType(), indexType)) + if (auto subType = getResultType(flipType.getElementType(), indexType, loc)) return FlipType::get(subType); return {}; @@ -1076,31 +1080,20 @@ FIRRTLType firrtl::getReductionResult(FIRRTLType input) { static LogicalResult verifyBitsPrimOp(BitsPrimOp bits) { uint32_t hi = bits.hi(), lo = bits.lo(); - // High must be >= low. - if (hi < lo) { - bits.emitError() - << "high must be equal or greater than low, but got high = " << hi - << ", low = " << lo; + auto expectedType = + BitsPrimOp::getResultType(bits.input().getType().cast(), + (int32_t)hi, (int32_t)lo, bits.getLoc()); + if (!expectedType) return failure(); - } - // Input width must be > high. - int32_t width = - bits.input().getType().cast().getBitWidthOrSentinel(); - if (width != -1 && int32_t(hi) >= width) { - bits.emitError() - << "high must be smaller than the width of input, but got high = " << hi - << ", width = " << width; - return failure(); - } - - // Result type should be int type with (high - low + 1) width. int32_t resultWidth = bits.result().getType().cast().getBitWidthOrSentinel(); - if (resultWidth != -1 && int32_t(hi - lo + 1) != resultWidth) { + int32_t expectedWidth = expectedType.cast().getBitWidthOrSentinel(); + + if (resultWidth != -1 && expectedWidth != resultWidth) { bits.emitError() << "width of the result type must be equal to (high - low " "+ 1), expected " - << hi - lo + 1 << " but got " << resultWidth; + << expectedWidth << " but got " << resultWidth; return failure(); } @@ -1108,30 +1101,51 @@ static LogicalResult verifyBitsPrimOp(BitsPrimOp bits) { } FIRRTLType BitsPrimOp::getResultType(FIRRTLType input, int32_t high, - int32_t low) { + int32_t low, Location loc) { auto inputi = input.dyn_cast(); - // High must be >= low and both most be non-negative. - if (!inputi || high < low || low < 0) + if (!inputi) { + mlir::emitError(loc) << "input type should be the int type but got " + << input; return {}; + } + + // High must be >= low and both most be non-negative. + if (high < low) { + mlir::emitError(loc) + << "high must be equal or greater than low, but got high = " << high + << ", low = " << low; + return {}; + } + + if (low < 0) { + mlir::emitError(loc) << "low must be non-negative but got" << low; + return {}; + } // If the input has staticly known width, check it. Both and low must be // strictly less than width. int32_t width = inputi.getWidthOrSentinel(); - if (width != -1 && high >= width) + if (width != -1 && high >= width) { + mlir::emitError(loc) + << "high must be smaller than the width of input, but got high = " + << high << ", width = " << width; return {}; + } return UIntType::get(input.getContext(), high - low + 1); } void BitsPrimOp::build(OpBuilder &builder, OperationState &result, Value input, unsigned high, unsigned low) { - auto type = getResultType(input.getType().cast(), high, low); + auto type = getResultType(input.getType().cast(), high, low, + result.location); assert(type && "invalid inputs building BitsPrimOp!"); build(builder, result, type, input, high, low); } -FIRRTLType HeadPrimOp::getResultType(FIRRTLType input, int32_t amount) { +FIRRTLType HeadPrimOp::getResultType(FIRRTLType input, int32_t amount, + Location loc) { auto inputi = input.dyn_cast(); if (amount <= 0 || !inputi) return {}; @@ -1145,7 +1159,7 @@ FIRRTLType HeadPrimOp::getResultType(FIRRTLType input, int32_t amount) { } FIRRTLType MuxPrimOp::getResultType(FIRRTLType sel, FIRRTLType high, - FIRRTLType low) { + FIRRTLType low, Location loc) { // Sel needs to be a one bit uint or an unknown width uint. auto selui = sel.dyn_cast(); if (!selui || selui.getWidthOrSentinel() > 1) @@ -1191,7 +1205,8 @@ FIRRTLType MuxPrimOp::getResultType(FIRRTLType sel, FIRRTLType high, return {}; } -FIRRTLType PadPrimOp::getResultType(FIRRTLType input, int32_t amount) { +FIRRTLType PadPrimOp::getResultType(FIRRTLType input, int32_t amount, + Location loc) { auto inputi = input.dyn_cast(); if (amount < 0 || !inputi) return {}; @@ -1204,7 +1219,8 @@ FIRRTLType PadPrimOp::getResultType(FIRRTLType input, int32_t amount) { return IntType::get(input.getContext(), inputi.isSigned(), width); } -FIRRTLType ShlPrimOp::getResultType(FIRRTLType input, int32_t amount) { +FIRRTLType ShlPrimOp::getResultType(FIRRTLType input, int32_t amount, + Location loc) { auto inputi = input.dyn_cast(); if (amount < 0 || !inputi) return {}; @@ -1216,7 +1232,8 @@ FIRRTLType ShlPrimOp::getResultType(FIRRTLType input, int32_t amount) { return IntType::get(input.getContext(), inputi.isSigned(), width); } -FIRRTLType ShrPrimOp::getResultType(FIRRTLType input, int32_t amount) { +FIRRTLType ShrPrimOp::getResultType(FIRRTLType input, int32_t amount, + Location loc) { auto inputi = input.dyn_cast(); if (amount < 0 || !inputi) return {}; @@ -1228,7 +1245,8 @@ FIRRTLType ShrPrimOp::getResultType(FIRRTLType input, int32_t amount) { return IntType::get(input.getContext(), inputi.isSigned(), width); } -FIRRTLType TailPrimOp::getResultType(FIRRTLType input, int32_t amount) { +FIRRTLType TailPrimOp::getResultType(FIRRTLType input, int32_t amount, + Location loc) { auto inputi = input.dyn_cast(); if (amount < 0 || !inputi) return {}; diff --git a/lib/FIRParser/FIRParser.cpp b/lib/FIRParser/FIRParser.cpp index faba71b781..6ba8de9388 100644 --- a/lib/FIRParser/FIRParser.cpp +++ b/lib/FIRParser/FIRParser.cpp @@ -902,7 +902,8 @@ ParseResult FIRStmtParser::parsePostFixFieldId(Value &result, // Make sure the field name matches up with the input value's type and // compute the result type for the expression. auto resultType = result.getType().cast(); - resultType = SubfieldOp::getResultType(resultType, fieldName); + resultType = + SubfieldOp::getResultType(resultType, fieldName, translateLocation(loc)); if (!resultType) { // TODO(QoI): This error would be nicer with a .fir pretty print of the // type. @@ -938,7 +939,8 @@ ParseResult FIRStmtParser::parsePostFixIntSubscript(Value &result, // Make sure the index expression is valid and compute the result type for the // expression. auto resultType = result.getType().cast(); - resultType = SubindexOp::getResultType(resultType, indexNo); + resultType = SubindexOp::getResultType(resultType, indexNo, + translateLocation(indexLoc)); if (!resultType) { // TODO(QoI): This error would be nicer with a .fir pretty print of the // type. @@ -979,7 +981,8 @@ ParseResult FIRStmtParser::parsePostFixDynamicSubscript(Value &result, // Make sure the index expression is valid and compute the result type for the // expression. auto resultType = result.getType().cast(); - resultType = SubaccessOp::getResultType(resultType, indexType); + resultType = SubaccessOp::getResultType(resultType, indexType, + translateLocation(indexLoc)); if (!resultType) { // TODO(QoI): This error would be nicer with a .fir pretty print of the // type. @@ -1086,7 +1089,8 @@ ParseResult FIRStmtParser::parsePrimExp(Value &result, SubOpVector &subOps) { #define TOK_LPKEYWORD_PRIM(SPELLING, CLASS) \ case FIRToken::lp_##SPELLING: { \ - auto resultTy = CLASS::getResultType(opTypes, integers); \ + auto resultTy = \ + CLASS::getResultType(opTypes, integers, translateLocation(loc)); \ if (!resultTy) \ return typeError(#SPELLING); \ result = builder.create(translateLocation(loc), resultTy, \ diff --git a/test/FIRParser/errors.fir b/test/FIRParser/errors.fir index c16d7ea902..621bb94a69 100644 --- a/test/FIRParser/errors.fir +++ b/test/FIRParser/errors.fir @@ -112,3 +112,13 @@ circuit test : module invalid_name : input bf: { flip int_1 : UInt<1>, int_out : UInt<2>} node n4 = add(bf, bf) ; expected-error {{invalid input types for 'add'}} + +;// ----- + +circuit test : + module invalid_bits: + input a: UInt<8> + output b: UInt<4> + ; expected-error@+2 {{high must be equal or greater than low, but got high = 4, low = 7}}; + ; expected-error@+1 {{invalid input types for 'bits': '!firrtl.uint<8>'}} + b <= bits(a, 4, 7) \ No newline at end of file