[mlir] Add value_begin/value_end methods to DenseElementsAttr

Currently DenseElementsAttr only exposes the ability to get the full range of values for a given type T, but there are many situations where we just want the beginning/end iterator. This revision adds proper value_begin/value_end methods for all of the supported T types, and also cleans up a bit of the interface.

Differential Revision: https://reviews.llvm.org/D104173
This commit is contained in:
River Riddle 2021-09-21 01:40:22 +00:00
parent 4f21152af1
commit 0cb5d7fc7f
18 changed files with 300 additions and 169 deletions

View File

@ -165,7 +165,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// functor recursively walks the dimensions of the constant shape,
// generating a store when the recursion hits the base case.
SmallVector<Value, 2> indices;
auto valueIt = constantValue.getValues<FloatAttr>().begin();
auto valueIt = constantValue.value_begin<FloatAttr>();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
// The last dimension is the base case of the recursion, at this point
// we store the element at the given index.

View File

@ -164,7 +164,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// functor recursively walks the dimensions of the constant shape,
// generating a store when the recursion hits the base case.
SmallVector<Value, 2> indices;
auto valueIt = constantValue.getValues<FloatAttr>().begin();
auto valueIt = constantValue.value_begin<FloatAttr>();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
// The last dimension is the base case of the recursion, at this point
// we store the element at the given index.

View File

@ -165,7 +165,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// functor recursively walks the dimensions of the constant shape,
// generating a store when the recursion hits the base case.
SmallVector<Value, 2> indices;
auto valueIt = constantValue.getValues<FloatAttr>().begin();
auto valueIt = constantValue.value_begin<FloatAttr>();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
// The last dimension is the base case of the recursion, at this point
// we store the element at the given index.

View File

@ -58,8 +58,8 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
auto lhs = operands[0].cast<ElementsAttr>();
auto rhs = operands[1].cast<ElementsAttr>();
auto lhsIt = lhs.getValues<ElementValueT>().begin();
auto rhsIt = rhs.getValues<ElementValueT>().begin();
auto lhsIt = lhs.value_begin<ElementValueT>();
auto rhsIt = rhs.value_begin<ElementValueT>();
SmallVector<ElementValueT, 4> elementResults;
elementResults.reserve(lhs.getNumElements());
for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)

View File

@ -51,6 +51,9 @@ public:
/// with static shape.
ShapedType getType() const;
/// Return the element type of this ElementsAttr.
Type getElementType() const;
/// Return the value at the given index. The index is expected to refer to a
/// valid element.
Attribute getValue(ArrayRef<uint64_t> index) const;
@ -65,8 +68,9 @@ public:
/// Return the elements of this attribute as a value of type 'T'. Note:
/// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
/// iteration.
template <typename T>
iterator_range<T> getValues() const;
template <typename T> iterator_range<T> getValues() const;
template <typename T> iterator<T> value_begin() const;
template <typename T> iterator<T> value_end() const;
/// Return if the given 'index' refers to a valid element in this attribute.
bool isValidIndex(ArrayRef<uint64_t> index) const;
@ -417,7 +421,7 @@ public:
T>::type
getSplatValue() const {
assert(isSplat() && "expected the attribute to be a splat");
return *getValues<T>().begin();
return *value_begin<T>();
}
/// Return the splat value for derived attribute element types.
template <typename T>
@ -436,15 +440,21 @@ public:
template <typename T>
T getValue(ArrayRef<uint64_t> index) const {
// Skip to the element corresponding to the flattened index.
return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
return getFlatValue<T>(getFlattenedIndex(index));
}
/// Return the value at the given flattened index.
template <typename T> T getFlatValue(uint64_t index) const {
return *std::next(value_begin<T>(), index);
}
/// Return the held element values as a range of integer or floating-point
/// values.
template <typename T, typename = typename std::enable_if<
(!std::is_same<T, bool>::value &&
template <typename T>
using IntFloatValueTemplateCheckT =
typename std::enable_if<(!std::is_same<T, bool>::value &&
std::numeric_limits<T>::is_integer) ||
is_valid_cpp_fp_type<T>::value>::type>
is_valid_cpp_fp_type<T>::value>::type;
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
llvm::iterator_range<ElementIterator<T>> getValues() const {
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
std::numeric_limits<T>::is_signed));
@ -453,13 +463,27 @@ public:
return {ElementIterator<T>(rawData, splat, 0),
ElementIterator<T>(rawData, splat, getNumElements())};
}
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
ElementIterator<T> value_begin() const {
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
std::numeric_limits<T>::is_signed));
return ElementIterator<T>(getRawData().data(), isSplat(), 0);
}
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
ElementIterator<T> value_end() const {
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
std::numeric_limits<T>::is_signed));
return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
}
/// Return the held element values as a range of std::complex.
template <typename T, typename ElementT = typename T::value_type,
typename = typename std::enable_if<
detail::is_complex_t<T>::value &&
template <typename T, typename ElementT>
using ComplexValueTemplateCheckT =
typename std::enable_if<detail::is_complex_t<T>::value &&
(std::numeric_limits<ElementT>::is_integer ||
is_valid_cpp_fp_type<ElementT>::value)>::type>
is_valid_cpp_fp_type<ElementT>::value)>::type;
template <typename T, typename ElementT = typename T::value_type,
typename = ComplexValueTemplateCheckT<T, ElementT>>
llvm::iterator_range<ElementIterator<T>> getValues() const {
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
std::numeric_limits<ElementT>::is_signed));
@ -468,10 +492,26 @@ public:
return {ElementIterator<T>(rawData, splat, 0),
ElementIterator<T>(rawData, splat, getNumElements())};
}
template <typename T, typename ElementT = typename T::value_type,
typename = ComplexValueTemplateCheckT<T, ElementT>>
ElementIterator<T> value_begin() const {
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
std::numeric_limits<ElementT>::is_signed));
return ElementIterator<T>(getRawData().data(), isSplat(), 0);
}
template <typename T, typename ElementT = typename T::value_type,
typename = ComplexValueTemplateCheckT<T, ElementT>>
ElementIterator<T> value_end() const {
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
std::numeric_limits<ElementT>::is_signed));
return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
}
/// Return the held element values as a range of StringRef.
template <typename T, typename = typename std::enable_if<
std::is_same<T, StringRef>::value>::type>
template <typename T>
using StringRefValueTemplateCheckT =
typename std::enable_if<std::is_same<T, StringRef>::value>::type;
template <typename T, typename = StringRefValueTemplateCheckT<T>>
llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
auto stringRefs = getRawStringData();
const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
@ -479,80 +519,156 @@ public:
return {ElementIterator<StringRef>(ptr, splat, 0),
ElementIterator<StringRef>(ptr, splat, getNumElements())};
}
template <typename T, typename = StringRefValueTemplateCheckT<T>>
ElementIterator<StringRef> value_begin() const {
const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
return ElementIterator<StringRef>(ptr, isSplat(), 0);
}
template <typename T, typename = StringRefValueTemplateCheckT<T>>
ElementIterator<StringRef> value_end() const {
const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
return ElementIterator<StringRef>(ptr, isSplat(), getNumElements());
}
/// Return the held element values as a range of Attributes.
llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
template <typename T, typename = typename std::enable_if<
std::is_same<T, Attribute>::value>::type>
template <typename T>
using AttributeValueTemplateCheckT =
typename std::enable_if<std::is_same<T, Attribute>::value>::type;
template <typename T, typename = AttributeValueTemplateCheckT<T>>
llvm::iterator_range<AttributeElementIterator> getValues() const {
return getAttributeValues();
return {value_begin<Attribute>(), value_end<Attribute>()};
}
template <typename T, typename = AttributeValueTemplateCheckT<T>>
AttributeElementIterator value_begin() const {
return AttributeElementIterator(*this, 0);
}
template <typename T, typename = AttributeValueTemplateCheckT<T>>
AttributeElementIterator value_end() const {
return AttributeElementIterator(*this, getNumElements());
}
AttributeElementIterator attr_value_begin() const;
AttributeElementIterator attr_value_end() const;
/// Return the held element values a range of T, where T is a derived
/// attribute type.
template <typename T>
using DerivedAttrValueTemplateCheckT =
typename std::enable_if<std::is_base_of<Attribute, T>::value &&
!std::is_same<Attribute, T>::value>::type;
template <typename T>
using DerivedAttributeElementIterator =
llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
template <typename T, typename = typename std::enable_if<
std::is_base_of<Attribute, T>::value &&
!std::is_same<Attribute, T>::value>::type>
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
return llvm::map_range(getAttributeValues(),
return llvm::map_range(getValues<Attribute>(),
static_cast<T (*)(Attribute)>(castFn));
}
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
DerivedAttributeElementIterator<T> value_begin() const {
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
return {value_begin<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
}
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
DerivedAttributeElementIterator<T> value_end() const {
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
return {value_end<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
}
/// Return the held element values as a range of bool. The element type of
/// this attribute must be of integer type of bitwidth 1.
llvm::iterator_range<BoolElementIterator> getBoolValues() const;
template <typename T, typename = typename std::enable_if<
std::is_same<T, bool>::value>::type>
template <typename T>
using BoolValueTemplateCheckT =
typename std::enable_if<std::is_same<T, bool>::value>::type;
template <typename T, typename = BoolValueTemplateCheckT<T>>
llvm::iterator_range<BoolElementIterator> getValues() const {
return getBoolValues();
assert(isValidBool() && "bool is not the value of this elements attribute");
return {BoolElementIterator(*this, 0),
BoolElementIterator(*this, getNumElements())};
}
template <typename T, typename = BoolValueTemplateCheckT<T>>
BoolElementIterator value_begin() const {
assert(isValidBool() && "bool is not the value of this elements attribute");
return BoolElementIterator(*this, 0);
}
template <typename T, typename = BoolValueTemplateCheckT<T>>
BoolElementIterator value_end() const {
assert(isValidBool() && "bool is not the value of this elements attribute");
return BoolElementIterator(*this, getNumElements());
}
/// Return the held element values as a range of APInts. The element type of
/// this attribute must be of integer type.
llvm::iterator_range<IntElementIterator> getIntValues() const;
template <typename T, typename = typename std::enable_if<
std::is_same<T, APInt>::value>::type>
template <typename T>
using APIntValueTemplateCheckT =
typename std::enable_if<std::is_same<T, APInt>::value>::type;
template <typename T, typename = APIntValueTemplateCheckT<T>>
llvm::iterator_range<IntElementIterator> getValues() const {
return getIntValues();
assert(getElementType().isIntOrIndex() && "expected integral type");
return {raw_int_begin(), raw_int_end()};
}
template <typename T, typename = APIntValueTemplateCheckT<T>>
IntElementIterator value_begin() const {
assert(getElementType().isIntOrIndex() && "expected integral type");
return raw_int_begin();
}
template <typename T, typename = APIntValueTemplateCheckT<T>>
IntElementIterator value_end() const {
assert(getElementType().isIntOrIndex() && "expected integral type");
return raw_int_end();
}
IntElementIterator int_value_begin() const;
IntElementIterator int_value_end() const;
/// Return the held element values as a range of complex APInts. The element
/// type of this attribute must be a complex of integer type.
llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
template <typename T, typename = typename std::enable_if<
std::is_same<T, std::complex<APInt>>::value>::type>
template <typename T>
using ComplexAPIntValueTemplateCheckT = typename std::enable_if<
std::is_same<T, std::complex<APInt>>::value>::type;
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
llvm::iterator_range<ComplexIntElementIterator> getValues() const {
return getComplexIntValues();
}
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
ComplexIntElementIterator value_begin() const {
return complex_value_begin();
}
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
ComplexIntElementIterator value_end() const {
return complex_value_end();
}
/// Return the held element values as a range of APFloat. The element type of
/// this attribute must be of float type.
llvm::iterator_range<FloatElementIterator> getFloatValues() const;
template <typename T, typename = typename std::enable_if<
std::is_same<T, APFloat>::value>::type>
template <typename T>
using APFloatValueTemplateCheckT =
typename std::enable_if<std::is_same<T, APFloat>::value>::type;
template <typename T, typename = APFloatValueTemplateCheckT<T>>
llvm::iterator_range<FloatElementIterator> getValues() const {
return getFloatValues();
}
FloatElementIterator float_value_begin() const;
FloatElementIterator float_value_end() const;
template <typename T, typename = APFloatValueTemplateCheckT<T>>
FloatElementIterator value_begin() const {
return float_value_begin();
}
template <typename T, typename = APFloatValueTemplateCheckT<T>>
FloatElementIterator value_end() const {
return float_value_end();
}
/// Return the held element values as a range of complex APFloat. The element
/// type of this attribute must be a complex of float type.
llvm::iterator_range<ComplexFloatElementIterator>
getComplexFloatValues() const;
template <typename T, typename = typename std::enable_if<std::is_same<
T, std::complex<APFloat>>::value>::type>
template <typename T>
using ComplexAPFloatValueTemplateCheckT = typename std::enable_if<
std::is_same<T, std::complex<APFloat>>::value>::type;
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
return getComplexFloatValues();
}
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
ComplexFloatElementIterator value_begin() const {
return complex_float_value_begin();
}
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
ComplexFloatElementIterator value_end() const {
return complex_float_value_end();
}
/// Return the raw storage data held by this attribute. Users should generally
/// not use this directly, as the internal storage format is not always in the
@ -590,13 +706,25 @@ public:
function_ref<APInt(const APFloat &)> mapping) const;
protected:
/// Get iterators to the raw APInt values for each element in this attribute.
/// Iterators to various elements that require out-of-line definition. These
/// are hidden from the user to encourage consistent use of the
/// getValues/value_begin/value_end API.
IntElementIterator raw_int_begin() const {
return IntElementIterator(*this, 0);
}
IntElementIterator raw_int_end() const {
return IntElementIterator(*this, getNumElements());
}
llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
ComplexIntElementIterator complex_value_begin() const;
ComplexIntElementIterator complex_value_end() const;
llvm::iterator_range<FloatElementIterator> getFloatValues() const;
FloatElementIterator float_value_begin() const;
FloatElementIterator float_value_end() const;
llvm::iterator_range<ComplexFloatElementIterator>
getComplexFloatValues() const;
ComplexFloatElementIterator complex_float_value_begin() const;
ComplexFloatElementIterator complex_float_value_end() const;
/// Overload of the raw 'get' method that asserts that the given type is of
/// complex type. This method is used to verify type invariants that the
@ -616,11 +744,8 @@ protected:
/// Check the information for a C++ data type, check if this type is valid for
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
bool isValidBool() const { return getElementType().isInteger(1); }
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
/// Check the information for a C++ data type, check if this type is valid for
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
};
@ -806,7 +931,7 @@ template <typename T>
auto SparseElementsAttr::getValues() const
-> llvm::iterator_range<iterator<T>> {
auto zeroValue = getZeroValue<T>();
auto valueIt = getValues().getValues<T>().begin();
auto valueIt = getValues().value_begin<T>();
const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
std::function<T(ptrdiff_t)> mapFn =
[flatSparseIndices{std::move(flatSparseIndices)},
@ -821,6 +946,14 @@ auto SparseElementsAttr::getValues() const
};
return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
}
template <typename T>
auto SparseElementsAttr::value_begin() const -> iterator<T> {
return getValues<T>().begin();
}
template <typename T>
auto SparseElementsAttr::value_end() const -> iterator<T> {
return getValues<T>().end();
}
namespace detail {
/// This class represents a general iterator over the values of an ElementsAttr.
@ -833,8 +966,7 @@ class ElementsAttrIterator
// NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype'
// inside of a conversion operator.
using DenseIteratorT = typename std::enable_if<
true,
decltype(std::declval<DenseElementsAttr>().getValues<T>().begin())>::type;
true, decltype(std::declval<DenseElementsAttr>().value_begin<T>())>::type;
using SparseIteratorT = SparseElementsAttr::iterator<T>;
/// A union containing the specific iterators for each derived attribute kind.
@ -960,6 +1092,21 @@ auto ElementsAttr::getValues() const -> iterator_range<T> {
llvm_unreachable("unexpected attribute kind");
}
template <typename T> auto ElementsAttr::value_begin() const -> iterator<T> {
if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>())
return iterator<T>(*this, denseAttr.value_begin<T>());
if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>())
return iterator<T>(*this, sparseAttr.value_begin<T>());
llvm_unreachable("unexpected attribute kind");
}
template <typename T> auto ElementsAttr::value_end() const -> iterator<T> {
if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>())
return iterator<T>(*this, denseAttr.value_end<T>());
if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>())
return iterator<T>(*this, sparseAttr.value_end<T>());
llvm_unreachable("unexpected attribute kind");
}
} // end namespace mlir.
//===----------------------------------------------------------------------===//

View File

@ -721,6 +721,8 @@ def Builtin_SparseElementsAttr
/// 'T' may be any of Attribute, APInt, APFloat, c++ integer/float types,
/// etc.
template <typename T> llvm::iterator_range<iterator<T>> getValues() const;
template <typename T> iterator<T> value_begin() const;
template <typename T> iterator<T> value_end() const;
/// Return the value of the element at the given index. The 'index' is
/// expected to refer to a valid element.

View File

@ -505,48 +505,36 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
// Indexed accessors.
bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
pos);
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<bool>(pos);
}
int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>().begin() +
pos);
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int8_t>(pos);
}
uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>().begin() +
pos);
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint8_t>(pos);
}
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
pos);
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int32_t>(pos);
}
uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
return *(
unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() +
pos);
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint32_t>(pos);
}
int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() +
pos);
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int64_t>(pos);
}
uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
return *(
unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() +
pos);
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint64_t>(pos);
}
float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() +
pos);
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<float>(pos);
}
double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() +
pos);
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<double>(pos);
}
MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
intptr_t pos) {
return wrap(
*(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() +
pos));
unwrap(attr).cast<DenseElementsAttr>().getFlatValue<StringRef>(pos));
}
//===----------------------------------------------------------------------===//

View File

@ -127,7 +127,7 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
return failure();
if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1),
[](const APInt &size) { return !size.isOneValue(); }))
return failure();

View File

@ -558,9 +558,9 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
if (srcElemType != dstElemType) {
SmallVector<Attribute, 8> elements;
if (srcElemType.isa<FloatType>()) {
for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
FloatAttr dstAttr = convertFloatAttr(
srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
FloatAttr dstAttr =
convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@ -568,10 +568,9 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
} else if (srcElemType.isInteger(1)) {
return failure();
} else {
for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
IntegerAttr dstAttr =
convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
dstElemType.cast<IntegerType>(), rewriter);
for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
IntegerAttr dstAttr = convertIntegerAttr(
srcAttr, dstElemType.cast<IntegerType>(), rewriter);
if (!dstAttr)
return failure();
elements.push_back(dstAttr);

View File

@ -1610,7 +1610,7 @@ public:
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.resize(resultTy.getRank());
for (auto permutation : llvm::enumerate(perms.getIntValues())) {
for (auto permutation : llvm::enumerate(perms.getValues<APInt>())) {
inputExprs[permutation.value().getZExtValue()] =
rewriter.getAffineDimExpr(permutation.index());
}

View File

@ -337,11 +337,12 @@ void gpu::addAsyncDependency(Operation *op, Value token) {
auto attrName =
OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);
// Async dependencies is the only variadic operand.
if (!sizeAttr)
return; // Async dependencies is the only variadic operand.
SmallVector<int32_t, 8> sizes;
for (auto size : sizeAttr.getIntValues())
sizes.push_back(size.getSExtValue());
return;
SmallVector<int32_t, 8> sizes(sizeAttr.getValues<int32_t>());
++sizes.front();
op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes));
}

View File

@ -1825,8 +1825,9 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
// and hence was replaced.
if (complexElementType.isa<IntegerType>()) {
bool isSigned = !complexElementType.isUnsignedInteger();
auto valueIt = attr.value_begin<std::complex<APInt>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
auto complexValue = *(attr.getComplexIntValues().begin() + index);
auto complexValue = *(valueIt + index);
os << "(";
printDenseIntElement(complexValue.real(), os, isSigned);
os << ",";
@ -1834,8 +1835,9 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
os << ")";
});
} else {
auto valueIt = attr.value_begin<std::complex<APFloat>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
auto complexValue = *(attr.getComplexFloatValues().begin() + index);
auto complexValue = *(valueIt + index);
os << "(";
printFloatValue(complexValue.real(), os);
os << ",";
@ -1845,15 +1847,15 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
}
} else if (elementType.isIntOrIndex()) {
bool isSigned = !elementType.isUnsignedInteger();
auto intValues = attr.getIntValues();
auto valueIt = attr.value_begin<APInt>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
printDenseIntElement(*(intValues.begin() + index), os, isSigned);
printDenseIntElement(*(valueIt + index), os, isSigned);
});
} else {
assert(elementType.isa<FloatType>() && "unexpected element type");
auto floatValues = attr.getFloatValues();
auto valueIt = attr.value_begin<APFloat>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
printFloatValue(*(floatValues.begin() + index), os);
printFloatValue(*(valueIt + index), os);
});
}
}

View File

@ -390,6 +390,8 @@ ShapedType ElementsAttr::getType() const {
return Attribute::getType().cast<ShapedType>();
}
Type ElementsAttr::getElementType() const { return getType().getElementType(); }
/// Returns the number of elements held by this attribute.
int64_t ElementsAttr::getNumElements() const {
return getType().getNumElements();
@ -635,7 +637,7 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
Type eltTy = owner.getType().getElementType();
Type eltTy = owner.getElementType();
if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
if (eltTy.isa<IndexType>())
@ -690,7 +692,7 @@ DenseElementsAttr::IntElementIterator::IntElementIterator(
DenseElementsAttr attr, size_t dataIndex)
: DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
attr.getRawData().data(), attr.isSplat(), dataIndex),
bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
APInt DenseElementsAttr::IntElementIterator::operator*() const {
return readBits(getData(),
@ -707,7 +709,7 @@ DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
std::complex<APInt>, std::complex<APInt>,
std::complex<APInt>>(
attr.getRawData().data(), attr.isSplat(), dataIndex) {
auto complexType = attr.getType().getElementType().cast<ComplexType>();
auto complexType = attr.getElementType().cast<ComplexType>();
bitWidth = getDenseElementBitWidth(complexType.getElementType());
}
@ -930,21 +932,15 @@ DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
isInt, isSigned);
}
/// A method used to verify specific type invariants that the templatized 'get'
/// method cannot.
bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
bool isSigned) const {
return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
isSigned);
return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
}
/// Check the information for a C++ data type, check if this type is valid for
/// the current attribute.
bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
bool isSigned) const {
return ::isValidIntOrFloat(
getType().getElementType().cast<ComplexType>().getElementType(),
dataEltSize / 2, isInt, isSigned);
getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
isInt, isSigned);
}
/// Returns true if this attribute corresponds to a splat, i.e. if all element
@ -953,76 +949,69 @@ bool DenseElementsAttr::isSplat() const {
return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
}
/// Return the held element values as a range of Attributes.
auto DenseElementsAttr::getAttributeValues() const
-> llvm::iterator_range<AttributeElementIterator> {
return {attr_value_begin(), attr_value_end()};
}
auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
return AttributeElementIterator(*this, 0);
}
auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
return AttributeElementIterator(*this, getNumElements());
/// Return if the given complex type has an integer element type.
static bool isComplexOfIntType(Type type) {
return type.cast<ComplexType>().getElementType().isa<IntegerType>();
}
/// Return the held element values as a range of bool. The element type of
/// this attribute must be of integer type of bitwidth 1.
auto DenseElementsAttr::getBoolValues() const
-> llvm::iterator_range<BoolElementIterator> {
auto eltType = getType().getElementType().dyn_cast<IntegerType>();
assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
(void)eltType;
return {BoolElementIterator(*this, 0),
BoolElementIterator(*this, getNumElements())};
}
/// Return the held element values as a range of APInts. The element type of
/// this attribute must be of integer type.
auto DenseElementsAttr::getIntValues() const
-> llvm::iterator_range<IntElementIterator> {
assert(getType().getElementType().isIntOrIndex() && "expected integral type");
return {raw_int_begin(), raw_int_end()};
}
auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
assert(getType().getElementType().isIntOrIndex() && "expected integral type");
return raw_int_begin();
}
auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
assert(getType().getElementType().isIntOrIndex() && "expected integral type");
return raw_int_end();
}
auto DenseElementsAttr::getComplexIntValues() const
-> llvm::iterator_range<ComplexIntElementIterator> {
Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
(void)eltTy;
assert(eltTy.isa<IntegerType>() && "expected complex integral type");
assert(isComplexOfIntType(getElementType()) &&
"expected complex integral type");
return {ComplexIntElementIterator(*this, 0),
ComplexIntElementIterator(*this, getNumElements())};
}
auto DenseElementsAttr::complex_value_begin() const
-> ComplexIntElementIterator {
assert(isComplexOfIntType(getElementType()) &&
"expected complex integral type");
return ComplexIntElementIterator(*this, 0);
}
auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
assert(isComplexOfIntType(getElementType()) &&
"expected complex integral type");
return ComplexIntElementIterator(*this, getNumElements());
}
/// Return the held element values as a range of APFloat. The element type of
/// this attribute must be of float type.
auto DenseElementsAttr::getFloatValues() const
-> llvm::iterator_range<FloatElementIterator> {
auto elementType = getType().getElementType().cast<FloatType>();
auto elementType = getElementType().cast<FloatType>();
const auto &elementSemantics = elementType.getFloatSemantics();
return {FloatElementIterator(elementSemantics, raw_int_begin()),
FloatElementIterator(elementSemantics, raw_int_end())};
}
auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
return getFloatValues().begin();
auto elementType = getElementType().cast<FloatType>();
return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
}
auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
return getFloatValues().end();
auto elementType = getElementType().cast<FloatType>();
return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
}
auto DenseElementsAttr::getComplexFloatValues() const
-> llvm::iterator_range<ComplexFloatElementIterator> {
Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
Type eltTy = getElementType().cast<ComplexType>().getElementType();
assert(eltTy.isa<FloatType>() && "expected complex float type");
const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
return {{semantics, {*this, 0}},
{semantics, {*this, static_cast<size_t>(getNumElements())}}};
}
auto DenseElementsAttr::complex_float_value_begin() const
-> ComplexFloatElementIterator {
Type eltTy = getElementType().cast<ComplexType>().getElementType();
assert(eltTy.isa<FloatType>() && "expected complex float type");
return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
}
auto DenseElementsAttr::complex_float_value_end() const
-> ComplexFloatElementIterator {
Type eltTy = getElementType().cast<ComplexType>().getElementType();
assert(eltTy.isa<FloatType>() && "expected complex float type");
return {eltTy.cast<FloatType>().getFloatSemantics(),
{*this, static_cast<size_t>(getNumElements())}};
}
/// Return the raw storage data held by this attribute.
ArrayRef<char> DenseElementsAttr::getRawData() const {
@ -1374,19 +1363,19 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
/// Get a zero APFloat for the given sparse attribute.
APFloat SparseElementsAttr::getZeroAPFloat() const {
auto eltType = getType().getElementType().cast<FloatType>();
auto eltType = getElementType().cast<FloatType>();
return APFloat(eltType.getFloatSemantics());
}
/// Get a zero APInt for the given sparse attribute.
APInt SparseElementsAttr::getZeroAPInt() const {
auto eltType = getType().getElementType().cast<IntegerType>();
auto eltType = getElementType().cast<IntegerType>();
return APInt::getZero(eltType.getWidth());
}
/// Get a zero attribute for the given attribute type.
Attribute SparseElementsAttr::getZeroAttr() const {
auto eltType = getType().getElementType();
auto eltType = getElementType();
// Handle floating point elements.
if (eltType.isa<FloatType>())

View File

@ -1024,7 +1024,7 @@ LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op,
return op->emitOpError("requires 1D i32 elements attribute '")
<< attrName << "'";
if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) {
if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) {
return !element.isNonNegative();
}))
return op->emitOpError("'")

View File

@ -51,7 +51,7 @@ void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
auto dattr = attr.cast<DenseIntElementsAttr>();
res.clear();
res.reserve(dattr.size());
for (auto it : dattr.getIntValues())
for (auto it : dattr.getValues<APInt>())
res.push_back(it.getSExtValue());
} else {
auto vals = val.get<ShapedTypeComponents *>()->getDims();
@ -71,7 +71,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
return t.cast<ShapedType>().getDimSize(index);
if (auto attr = val.dyn_cast<Attribute>())
return attr.cast<DenseIntElementsAttr>()
.getValue<APInt>({static_cast<uint64_t>(index)})
.getFlatValue<APInt>(index)
.getSExtValue();
auto *stc = val.get<ShapedTypeComponents *>();
return stc->getDims()[index];
@ -94,7 +94,7 @@ bool ShapeAdaptor::hasStaticShape() const {
return t.cast<ShapedType>().hasStaticShape();
if (auto attr = val.dyn_cast<Attribute>()) {
auto dattr = attr.cast<DenseIntElementsAttr>();
for (auto index : dattr.getIntValues())
for (auto index : dattr.getValues<APInt>())
if (ShapedType::isDynamic(index.getSExtValue()))
return false;
return true;
@ -115,7 +115,7 @@ int64_t ShapeAdaptor::getNumElements() const {
if (auto attr = val.dyn_cast<Attribute>()) {
auto dattr = attr.cast<DenseIntElementsAttr>();
int64_t num = 1;
for (auto index : dattr.getIntValues()) {
for (auto index : dattr.getValues<APInt>()) {
num *= index.getZExtValue();
assert(num >= 0 && "integer overflow in element count computation");
}

View File

@ -294,7 +294,8 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
if (!nested)
return nullptr;
values.append(nested.attr_value_begin(), nested.attr_value_end());
values.append(nested.value_begin<Attribute>(),
nested.value_end<Attribute>());
}
return DenseElementsAttr::get(outerType, values);

View File

@ -83,12 +83,14 @@ const char *opSegmentSizeAttrInitCode = R"(
auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>();
)";
const char *attrSizedSegmentValueRangeCalcCode = R"(
auto sizeAttrValues = sizeAttr.getValues<uint32_t>();
const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
if (sizeAttr.isSplat())
return {*sizeAttrValueIt * index, *sizeAttrValueIt};
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += *(sizeAttrValues.begin() + i);
unsigned size = *(sizeAttrValues.begin() + index);
return {start, size};
start += sizeAttrValueIt[i];
return {start, sizeAttrValueIt[index]};
)";
// The logic to calculate the actual value range for a declared operand
// of an op with variadic of variadic operands within the OpAdaptor.

View File

@ -158,7 +158,7 @@ TEST(StructsGenTest, GetElements) {
auto denseAttr = returnedAttr.dyn_cast<mlir::DenseElementsAttr>();
ASSERT_TRUE(denseAttr);
for (const auto &valIndexIt : llvm::enumerate(denseAttr.getIntValues())) {
for (const auto &valIndexIt : llvm::enumerate(denseAttr.getValues<APInt>())) {
EXPECT_EQ(valIndexIt.value(), valIndexIt.index() + 1);
}
}