[mlir] Add value_begin/value_end methods to DenseElementsAttr
Currently DenseElementsAttr only exposes the ability to get the full range of values for a given type T, but there are many situations where we just want the beginning/end iterator. This revision adds proper value_begin/value_end methods for all of the supported T types, and also cleans up a bit of the interface. Differential Revision: https://reviews.llvm.org/D104173
This commit is contained in:
parent
4f21152af1
commit
0cb5d7fc7f
|
@ -165,7 +165,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
// functor recursively walks the dimensions of the constant shape,
|
||||
// generating a store when the recursion hits the base case.
|
||||
SmallVector<Value, 2> indices;
|
||||
auto valueIt = constantValue.getValues<FloatAttr>().begin();
|
||||
auto valueIt = constantValue.value_begin<FloatAttr>();
|
||||
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
|
||||
// The last dimension is the base case of the recursion, at this point
|
||||
// we store the element at the given index.
|
||||
|
|
|
@ -164,7 +164,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
// functor recursively walks the dimensions of the constant shape,
|
||||
// generating a store when the recursion hits the base case.
|
||||
SmallVector<Value, 2> indices;
|
||||
auto valueIt = constantValue.getValues<FloatAttr>().begin();
|
||||
auto valueIt = constantValue.value_begin<FloatAttr>();
|
||||
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
|
||||
// The last dimension is the base case of the recursion, at this point
|
||||
// we store the element at the given index.
|
||||
|
|
|
@ -165,7 +165,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
// functor recursively walks the dimensions of the constant shape,
|
||||
// generating a store when the recursion hits the base case.
|
||||
SmallVector<Value, 2> indices;
|
||||
auto valueIt = constantValue.getValues<FloatAttr>().begin();
|
||||
auto valueIt = constantValue.value_begin<FloatAttr>();
|
||||
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
|
||||
// The last dimension is the base case of the recursion, at this point
|
||||
// we store the element at the given index.
|
||||
|
|
|
@ -58,8 +58,8 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
|
|||
auto lhs = operands[0].cast<ElementsAttr>();
|
||||
auto rhs = operands[1].cast<ElementsAttr>();
|
||||
|
||||
auto lhsIt = lhs.getValues<ElementValueT>().begin();
|
||||
auto rhsIt = rhs.getValues<ElementValueT>().begin();
|
||||
auto lhsIt = lhs.value_begin<ElementValueT>();
|
||||
auto rhsIt = rhs.value_begin<ElementValueT>();
|
||||
SmallVector<ElementValueT, 4> elementResults;
|
||||
elementResults.reserve(lhs.getNumElements());
|
||||
for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
|
||||
|
|
|
@ -51,6 +51,9 @@ public:
|
|||
/// with static shape.
|
||||
ShapedType getType() const;
|
||||
|
||||
/// Return the element type of this ElementsAttr.
|
||||
Type getElementType() const;
|
||||
|
||||
/// Return the value at the given index. The index is expected to refer to a
|
||||
/// valid element.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||
|
@ -65,8 +68,9 @@ public:
|
|||
/// Return the elements of this attribute as a value of type 'T'. Note:
|
||||
/// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
|
||||
/// iteration.
|
||||
template <typename T>
|
||||
iterator_range<T> getValues() const;
|
||||
template <typename T> iterator_range<T> getValues() const;
|
||||
template <typename T> iterator<T> value_begin() const;
|
||||
template <typename T> iterator<T> value_end() const;
|
||||
|
||||
/// Return if the given 'index' refers to a valid element in this attribute.
|
||||
bool isValidIndex(ArrayRef<uint64_t> index) const;
|
||||
|
@ -417,7 +421,7 @@ public:
|
|||
T>::type
|
||||
getSplatValue() const {
|
||||
assert(isSplat() && "expected the attribute to be a splat");
|
||||
return *getValues<T>().begin();
|
||||
return *value_begin<T>();
|
||||
}
|
||||
/// Return the splat value for derived attribute element types.
|
||||
template <typename T>
|
||||
|
@ -436,15 +440,21 @@ public:
|
|||
template <typename T>
|
||||
T getValue(ArrayRef<uint64_t> index) const {
|
||||
// Skip to the element corresponding to the flattened index.
|
||||
return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
|
||||
return getFlatValue<T>(getFlattenedIndex(index));
|
||||
}
|
||||
/// Return the value at the given flattened index.
|
||||
template <typename T> T getFlatValue(uint64_t index) const {
|
||||
return *std::next(value_begin<T>(), index);
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of integer or floating-point
|
||||
/// values.
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
(!std::is_same<T, bool>::value &&
|
||||
template <typename T>
|
||||
using IntFloatValueTemplateCheckT =
|
||||
typename std::enable_if<(!std::is_same<T, bool>::value &&
|
||||
std::numeric_limits<T>::is_integer) ||
|
||||
is_valid_cpp_fp_type<T>::value>::type>
|
||||
is_valid_cpp_fp_type<T>::value>::type;
|
||||
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<ElementIterator<T>> getValues() const {
|
||||
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
|
||||
std::numeric_limits<T>::is_signed));
|
||||
|
@ -453,13 +463,27 @@ public:
|
|||
return {ElementIterator<T>(rawData, splat, 0),
|
||||
ElementIterator<T>(rawData, splat, getNumElements())};
|
||||
}
|
||||
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
|
||||
ElementIterator<T> value_begin() const {
|
||||
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
|
||||
std::numeric_limits<T>::is_signed));
|
||||
return ElementIterator<T>(getRawData().data(), isSplat(), 0);
|
||||
}
|
||||
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
|
||||
ElementIterator<T> value_end() const {
|
||||
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
|
||||
std::numeric_limits<T>::is_signed));
|
||||
return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of std::complex.
|
||||
template <typename T, typename ElementT = typename T::value_type,
|
||||
typename = typename std::enable_if<
|
||||
detail::is_complex_t<T>::value &&
|
||||
template <typename T, typename ElementT>
|
||||
using ComplexValueTemplateCheckT =
|
||||
typename std::enable_if<detail::is_complex_t<T>::value &&
|
||||
(std::numeric_limits<ElementT>::is_integer ||
|
||||
is_valid_cpp_fp_type<ElementT>::value)>::type>
|
||||
is_valid_cpp_fp_type<ElementT>::value)>::type;
|
||||
template <typename T, typename ElementT = typename T::value_type,
|
||||
typename = ComplexValueTemplateCheckT<T, ElementT>>
|
||||
llvm::iterator_range<ElementIterator<T>> getValues() const {
|
||||
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
|
||||
std::numeric_limits<ElementT>::is_signed));
|
||||
|
@ -468,10 +492,26 @@ public:
|
|||
return {ElementIterator<T>(rawData, splat, 0),
|
||||
ElementIterator<T>(rawData, splat, getNumElements())};
|
||||
}
|
||||
template <typename T, typename ElementT = typename T::value_type,
|
||||
typename = ComplexValueTemplateCheckT<T, ElementT>>
|
||||
ElementIterator<T> value_begin() const {
|
||||
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
|
||||
std::numeric_limits<ElementT>::is_signed));
|
||||
return ElementIterator<T>(getRawData().data(), isSplat(), 0);
|
||||
}
|
||||
template <typename T, typename ElementT = typename T::value_type,
|
||||
typename = ComplexValueTemplateCheckT<T, ElementT>>
|
||||
ElementIterator<T> value_end() const {
|
||||
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
|
||||
std::numeric_limits<ElementT>::is_signed));
|
||||
return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of StringRef.
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_same<T, StringRef>::value>::type>
|
||||
template <typename T>
|
||||
using StringRefValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_same<T, StringRef>::value>::type;
|
||||
template <typename T, typename = StringRefValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
|
||||
auto stringRefs = getRawStringData();
|
||||
const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
|
||||
|
@ -479,80 +519,156 @@ public:
|
|||
return {ElementIterator<StringRef>(ptr, splat, 0),
|
||||
ElementIterator<StringRef>(ptr, splat, getNumElements())};
|
||||
}
|
||||
template <typename T, typename = StringRefValueTemplateCheckT<T>>
|
||||
ElementIterator<StringRef> value_begin() const {
|
||||
const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
|
||||
return ElementIterator<StringRef>(ptr, isSplat(), 0);
|
||||
}
|
||||
template <typename T, typename = StringRefValueTemplateCheckT<T>>
|
||||
ElementIterator<StringRef> value_end() const {
|
||||
const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
|
||||
return ElementIterator<StringRef>(ptr, isSplat(), getNumElements());
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of Attributes.
|
||||
llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_same<T, Attribute>::value>::type>
|
||||
template <typename T>
|
||||
using AttributeValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_same<T, Attribute>::value>::type;
|
||||
template <typename T, typename = AttributeValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<AttributeElementIterator> getValues() const {
|
||||
return getAttributeValues();
|
||||
return {value_begin<Attribute>(), value_end<Attribute>()};
|
||||
}
|
||||
template <typename T, typename = AttributeValueTemplateCheckT<T>>
|
||||
AttributeElementIterator value_begin() const {
|
||||
return AttributeElementIterator(*this, 0);
|
||||
}
|
||||
template <typename T, typename = AttributeValueTemplateCheckT<T>>
|
||||
AttributeElementIterator value_end() const {
|
||||
return AttributeElementIterator(*this, getNumElements());
|
||||
}
|
||||
AttributeElementIterator attr_value_begin() const;
|
||||
AttributeElementIterator attr_value_end() const;
|
||||
|
||||
/// Return the held element values a range of T, where T is a derived
|
||||
/// attribute type.
|
||||
template <typename T>
|
||||
using DerivedAttrValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_base_of<Attribute, T>::value &&
|
||||
!std::is_same<Attribute, T>::value>::type;
|
||||
template <typename T>
|
||||
using DerivedAttributeElementIterator =
|
||||
llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_base_of<Attribute, T>::value &&
|
||||
!std::is_same<Attribute, T>::value>::type>
|
||||
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
|
||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||
return llvm::map_range(getAttributeValues(),
|
||||
return llvm::map_range(getValues<Attribute>(),
|
||||
static_cast<T (*)(Attribute)>(castFn));
|
||||
}
|
||||
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
||||
DerivedAttributeElementIterator<T> value_begin() const {
|
||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||
return {value_begin<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
|
||||
}
|
||||
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
||||
DerivedAttributeElementIterator<T> value_end() const {
|
||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||
return {value_end<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of bool. The element type of
|
||||
/// this attribute must be of integer type of bitwidth 1.
|
||||
llvm::iterator_range<BoolElementIterator> getBoolValues() const;
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_same<T, bool>::value>::type>
|
||||
template <typename T>
|
||||
using BoolValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_same<T, bool>::value>::type;
|
||||
template <typename T, typename = BoolValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<BoolElementIterator> getValues() const {
|
||||
return getBoolValues();
|
||||
assert(isValidBool() && "bool is not the value of this elements attribute");
|
||||
return {BoolElementIterator(*this, 0),
|
||||
BoolElementIterator(*this, getNumElements())};
|
||||
}
|
||||
template <typename T, typename = BoolValueTemplateCheckT<T>>
|
||||
BoolElementIterator value_begin() const {
|
||||
assert(isValidBool() && "bool is not the value of this elements attribute");
|
||||
return BoolElementIterator(*this, 0);
|
||||
}
|
||||
template <typename T, typename = BoolValueTemplateCheckT<T>>
|
||||
BoolElementIterator value_end() const {
|
||||
assert(isValidBool() && "bool is not the value of this elements attribute");
|
||||
return BoolElementIterator(*this, getNumElements());
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of APInts. The element type of
|
||||
/// this attribute must be of integer type.
|
||||
llvm::iterator_range<IntElementIterator> getIntValues() const;
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_same<T, APInt>::value>::type>
|
||||
template <typename T>
|
||||
using APIntValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_same<T, APInt>::value>::type;
|
||||
template <typename T, typename = APIntValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<IntElementIterator> getValues() const {
|
||||
return getIntValues();
|
||||
assert(getElementType().isIntOrIndex() && "expected integral type");
|
||||
return {raw_int_begin(), raw_int_end()};
|
||||
}
|
||||
template <typename T, typename = APIntValueTemplateCheckT<T>>
|
||||
IntElementIterator value_begin() const {
|
||||
assert(getElementType().isIntOrIndex() && "expected integral type");
|
||||
return raw_int_begin();
|
||||
}
|
||||
template <typename T, typename = APIntValueTemplateCheckT<T>>
|
||||
IntElementIterator value_end() const {
|
||||
assert(getElementType().isIntOrIndex() && "expected integral type");
|
||||
return raw_int_end();
|
||||
}
|
||||
IntElementIterator int_value_begin() const;
|
||||
IntElementIterator int_value_end() const;
|
||||
|
||||
/// Return the held element values as a range of complex APInts. The element
|
||||
/// type of this attribute must be a complex of integer type.
|
||||
llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_same<T, std::complex<APInt>>::value>::type>
|
||||
template <typename T>
|
||||
using ComplexAPIntValueTemplateCheckT = typename std::enable_if<
|
||||
std::is_same<T, std::complex<APInt>>::value>::type;
|
||||
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<ComplexIntElementIterator> getValues() const {
|
||||
return getComplexIntValues();
|
||||
}
|
||||
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
|
||||
ComplexIntElementIterator value_begin() const {
|
||||
return complex_value_begin();
|
||||
}
|
||||
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
|
||||
ComplexIntElementIterator value_end() const {
|
||||
return complex_value_end();
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of APFloat. The element type of
|
||||
/// this attribute must be of float type.
|
||||
llvm::iterator_range<FloatElementIterator> getFloatValues() const;
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
std::is_same<T, APFloat>::value>::type>
|
||||
template <typename T>
|
||||
using APFloatValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_same<T, APFloat>::value>::type;
|
||||
template <typename T, typename = APFloatValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<FloatElementIterator> getValues() const {
|
||||
return getFloatValues();
|
||||
}
|
||||
FloatElementIterator float_value_begin() const;
|
||||
FloatElementIterator float_value_end() const;
|
||||
template <typename T, typename = APFloatValueTemplateCheckT<T>>
|
||||
FloatElementIterator value_begin() const {
|
||||
return float_value_begin();
|
||||
}
|
||||
template <typename T, typename = APFloatValueTemplateCheckT<T>>
|
||||
FloatElementIterator value_end() const {
|
||||
return float_value_end();
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of complex APFloat. The element
|
||||
/// type of this attribute must be a complex of float type.
|
||||
llvm::iterator_range<ComplexFloatElementIterator>
|
||||
getComplexFloatValues() const;
|
||||
template <typename T, typename = typename std::enable_if<std::is_same<
|
||||
T, std::complex<APFloat>>::value>::type>
|
||||
template <typename T>
|
||||
using ComplexAPFloatValueTemplateCheckT = typename std::enable_if<
|
||||
std::is_same<T, std::complex<APFloat>>::value>::type;
|
||||
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
|
||||
return getComplexFloatValues();
|
||||
}
|
||||
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
|
||||
ComplexFloatElementIterator value_begin() const {
|
||||
return complex_float_value_begin();
|
||||
}
|
||||
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
|
||||
ComplexFloatElementIterator value_end() const {
|
||||
return complex_float_value_end();
|
||||
}
|
||||
|
||||
/// Return the raw storage data held by this attribute. Users should generally
|
||||
/// not use this directly, as the internal storage format is not always in the
|
||||
|
@ -590,13 +706,25 @@ public:
|
|||
function_ref<APInt(const APFloat &)> mapping) const;
|
||||
|
||||
protected:
|
||||
/// Get iterators to the raw APInt values for each element in this attribute.
|
||||
/// Iterators to various elements that require out-of-line definition. These
|
||||
/// are hidden from the user to encourage consistent use of the
|
||||
/// getValues/value_begin/value_end API.
|
||||
IntElementIterator raw_int_begin() const {
|
||||
return IntElementIterator(*this, 0);
|
||||
}
|
||||
IntElementIterator raw_int_end() const {
|
||||
return IntElementIterator(*this, getNumElements());
|
||||
}
|
||||
llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
|
||||
ComplexIntElementIterator complex_value_begin() const;
|
||||
ComplexIntElementIterator complex_value_end() const;
|
||||
llvm::iterator_range<FloatElementIterator> getFloatValues() const;
|
||||
FloatElementIterator float_value_begin() const;
|
||||
FloatElementIterator float_value_end() const;
|
||||
llvm::iterator_range<ComplexFloatElementIterator>
|
||||
getComplexFloatValues() const;
|
||||
ComplexFloatElementIterator complex_float_value_begin() const;
|
||||
ComplexFloatElementIterator complex_float_value_end() const;
|
||||
|
||||
/// Overload of the raw 'get' method that asserts that the given type is of
|
||||
/// complex type. This method is used to verify type invariants that the
|
||||
|
@ -616,11 +744,8 @@ protected:
|
|||
/// Check the information for a C++ data type, check if this type is valid for
|
||||
/// the current attribute. This method is used to verify specific type
|
||||
/// invariants that the templatized 'getValues' method cannot.
|
||||
bool isValidBool() const { return getElementType().isInteger(1); }
|
||||
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
|
||||
|
||||
/// Check the information for a C++ data type, check if this type is valid for
|
||||
/// the current attribute. This method is used to verify specific type
|
||||
/// invariants that the templatized 'getValues' method cannot.
|
||||
bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
|
||||
};
|
||||
|
||||
|
@ -806,7 +931,7 @@ template <typename T>
|
|||
auto SparseElementsAttr::getValues() const
|
||||
-> llvm::iterator_range<iterator<T>> {
|
||||
auto zeroValue = getZeroValue<T>();
|
||||
auto valueIt = getValues().getValues<T>().begin();
|
||||
auto valueIt = getValues().value_begin<T>();
|
||||
const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
|
||||
std::function<T(ptrdiff_t)> mapFn =
|
||||
[flatSparseIndices{std::move(flatSparseIndices)},
|
||||
|
@ -821,6 +946,14 @@ auto SparseElementsAttr::getValues() const
|
|||
};
|
||||
return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
|
||||
}
|
||||
template <typename T>
|
||||
auto SparseElementsAttr::value_begin() const -> iterator<T> {
|
||||
return getValues<T>().begin();
|
||||
}
|
||||
template <typename T>
|
||||
auto SparseElementsAttr::value_end() const -> iterator<T> {
|
||||
return getValues<T>().end();
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
/// This class represents a general iterator over the values of an ElementsAttr.
|
||||
|
@ -833,8 +966,7 @@ class ElementsAttrIterator
|
|||
// NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype'
|
||||
// inside of a conversion operator.
|
||||
using DenseIteratorT = typename std::enable_if<
|
||||
true,
|
||||
decltype(std::declval<DenseElementsAttr>().getValues<T>().begin())>::type;
|
||||
true, decltype(std::declval<DenseElementsAttr>().value_begin<T>())>::type;
|
||||
using SparseIteratorT = SparseElementsAttr::iterator<T>;
|
||||
|
||||
/// A union containing the specific iterators for each derived attribute kind.
|
||||
|
@ -960,6 +1092,21 @@ auto ElementsAttr::getValues() const -> iterator_range<T> {
|
|||
llvm_unreachable("unexpected attribute kind");
|
||||
}
|
||||
|
||||
template <typename T> auto ElementsAttr::value_begin() const -> iterator<T> {
|
||||
if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>())
|
||||
return iterator<T>(*this, denseAttr.value_begin<T>());
|
||||
if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>())
|
||||
return iterator<T>(*this, sparseAttr.value_begin<T>());
|
||||
llvm_unreachable("unexpected attribute kind");
|
||||
}
|
||||
template <typename T> auto ElementsAttr::value_end() const -> iterator<T> {
|
||||
if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>())
|
||||
return iterator<T>(*this, denseAttr.value_end<T>());
|
||||
if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>())
|
||||
return iterator<T>(*this, sparseAttr.value_end<T>());
|
||||
llvm_unreachable("unexpected attribute kind");
|
||||
}
|
||||
|
||||
} // end namespace mlir.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -721,6 +721,8 @@ def Builtin_SparseElementsAttr
|
|||
/// 'T' may be any of Attribute, APInt, APFloat, c++ integer/float types,
|
||||
/// etc.
|
||||
template <typename T> llvm::iterator_range<iterator<T>> getValues() const;
|
||||
template <typename T> iterator<T> value_begin() const;
|
||||
template <typename T> iterator<T> value_end() const;
|
||||
|
||||
/// Return the value of the element at the given index. The 'index' is
|
||||
/// expected to refer to a valid element.
|
||||
|
|
|
@ -505,48 +505,36 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
|
|||
// Indexed accessors.
|
||||
|
||||
bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
|
||||
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
|
||||
pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<bool>(pos);
|
||||
}
|
||||
int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
|
||||
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>().begin() +
|
||||
pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int8_t>(pos);
|
||||
}
|
||||
uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
|
||||
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>().begin() +
|
||||
pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint8_t>(pos);
|
||||
}
|
||||
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
|
||||
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
|
||||
pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int32_t>(pos);
|
||||
}
|
||||
uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
|
||||
return *(
|
||||
unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() +
|
||||
pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint32_t>(pos);
|
||||
}
|
||||
int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
|
||||
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() +
|
||||
pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int64_t>(pos);
|
||||
}
|
||||
uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
|
||||
return *(
|
||||
unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() +
|
||||
pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint64_t>(pos);
|
||||
}
|
||||
float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
|
||||
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() +
|
||||
pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<float>(pos);
|
||||
}
|
||||
double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
|
||||
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() +
|
||||
pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<double>(pos);
|
||||
}
|
||||
MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
|
||||
intptr_t pos) {
|
||||
return wrap(
|
||||
*(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() +
|
||||
pos));
|
||||
unwrap(attr).cast<DenseElementsAttr>().getFlatValue<StringRef>(pos));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -127,7 +127,7 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
|
|||
|
||||
if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
|
||||
return failure();
|
||||
if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
|
||||
if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1),
|
||||
[](const APInt &size) { return !size.isOneValue(); }))
|
||||
return failure();
|
||||
|
||||
|
|
|
@ -558,9 +558,9 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
|
|||
if (srcElemType != dstElemType) {
|
||||
SmallVector<Attribute, 8> elements;
|
||||
if (srcElemType.isa<FloatType>()) {
|
||||
for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
|
||||
FloatAttr dstAttr = convertFloatAttr(
|
||||
srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
|
||||
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
|
||||
FloatAttr dstAttr =
|
||||
convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
|
||||
if (!dstAttr)
|
||||
return failure();
|
||||
elements.push_back(dstAttr);
|
||||
|
@ -568,10 +568,9 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
|
|||
} else if (srcElemType.isInteger(1)) {
|
||||
return failure();
|
||||
} else {
|
||||
for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
|
||||
IntegerAttr dstAttr =
|
||||
convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
|
||||
dstElemType.cast<IntegerType>(), rewriter);
|
||||
for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
|
||||
IntegerAttr dstAttr = convertIntegerAttr(
|
||||
srcAttr, dstElemType.cast<IntegerType>(), rewriter);
|
||||
if (!dstAttr)
|
||||
return failure();
|
||||
elements.push_back(dstAttr);
|
||||
|
|
|
@ -1610,7 +1610,7 @@ public:
|
|||
|
||||
SmallVector<AffineExpr, 2> inputExprs;
|
||||
inputExprs.resize(resultTy.getRank());
|
||||
for (auto permutation : llvm::enumerate(perms.getIntValues())) {
|
||||
for (auto permutation : llvm::enumerate(perms.getValues<APInt>())) {
|
||||
inputExprs[permutation.value().getZExtValue()] =
|
||||
rewriter.getAffineDimExpr(permutation.index());
|
||||
}
|
||||
|
|
|
@ -337,11 +337,12 @@ void gpu::addAsyncDependency(Operation *op, Value token) {
|
|||
auto attrName =
|
||||
OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
|
||||
auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);
|
||||
|
||||
// Async dependencies is the only variadic operand.
|
||||
if (!sizeAttr)
|
||||
return; // Async dependencies is the only variadic operand.
|
||||
SmallVector<int32_t, 8> sizes;
|
||||
for (auto size : sizeAttr.getIntValues())
|
||||
sizes.push_back(size.getSExtValue());
|
||||
return;
|
||||
|
||||
SmallVector<int32_t, 8> sizes(sizeAttr.getValues<int32_t>());
|
||||
++sizes.front();
|
||||
op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes));
|
||||
}
|
||||
|
|
|
@ -1825,8 +1825,9 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
|
|||
// and hence was replaced.
|
||||
if (complexElementType.isa<IntegerType>()) {
|
||||
bool isSigned = !complexElementType.isUnsignedInteger();
|
||||
auto valueIt = attr.value_begin<std::complex<APInt>>();
|
||||
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
|
||||
auto complexValue = *(attr.getComplexIntValues().begin() + index);
|
||||
auto complexValue = *(valueIt + index);
|
||||
os << "(";
|
||||
printDenseIntElement(complexValue.real(), os, isSigned);
|
||||
os << ",";
|
||||
|
@ -1834,8 +1835,9 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
|
|||
os << ")";
|
||||
});
|
||||
} else {
|
||||
auto valueIt = attr.value_begin<std::complex<APFloat>>();
|
||||
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
|
||||
auto complexValue = *(attr.getComplexFloatValues().begin() + index);
|
||||
auto complexValue = *(valueIt + index);
|
||||
os << "(";
|
||||
printFloatValue(complexValue.real(), os);
|
||||
os << ",";
|
||||
|
@ -1845,15 +1847,15 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
|
|||
}
|
||||
} else if (elementType.isIntOrIndex()) {
|
||||
bool isSigned = !elementType.isUnsignedInteger();
|
||||
auto intValues = attr.getIntValues();
|
||||
auto valueIt = attr.value_begin<APInt>();
|
||||
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
|
||||
printDenseIntElement(*(intValues.begin() + index), os, isSigned);
|
||||
printDenseIntElement(*(valueIt + index), os, isSigned);
|
||||
});
|
||||
} else {
|
||||
assert(elementType.isa<FloatType>() && "unexpected element type");
|
||||
auto floatValues = attr.getFloatValues();
|
||||
auto valueIt = attr.value_begin<APFloat>();
|
||||
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
|
||||
printFloatValue(*(floatValues.begin() + index), os);
|
||||
printFloatValue(*(valueIt + index), os);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -390,6 +390,8 @@ ShapedType ElementsAttr::getType() const {
|
|||
return Attribute::getType().cast<ShapedType>();
|
||||
}
|
||||
|
||||
Type ElementsAttr::getElementType() const { return getType().getElementType(); }
|
||||
|
||||
/// Returns the number of elements held by this attribute.
|
||||
int64_t ElementsAttr::getNumElements() const {
|
||||
return getType().getNumElements();
|
||||
|
@ -635,7 +637,7 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
|
|||
|
||||
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
|
||||
auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
|
||||
Type eltTy = owner.getType().getElementType();
|
||||
Type eltTy = owner.getElementType();
|
||||
if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
|
||||
return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
|
||||
if (eltTy.isa<IndexType>())
|
||||
|
@ -690,7 +692,7 @@ DenseElementsAttr::IntElementIterator::IntElementIterator(
|
|||
DenseElementsAttr attr, size_t dataIndex)
|
||||
: DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
|
||||
attr.getRawData().data(), attr.isSplat(), dataIndex),
|
||||
bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
|
||||
bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
|
||||
|
||||
APInt DenseElementsAttr::IntElementIterator::operator*() const {
|
||||
return readBits(getData(),
|
||||
|
@ -707,7 +709,7 @@ DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
|
|||
std::complex<APInt>, std::complex<APInt>,
|
||||
std::complex<APInt>>(
|
||||
attr.getRawData().data(), attr.isSplat(), dataIndex) {
|
||||
auto complexType = attr.getType().getElementType().cast<ComplexType>();
|
||||
auto complexType = attr.getElementType().cast<ComplexType>();
|
||||
bitWidth = getDenseElementBitWidth(complexType.getElementType());
|
||||
}
|
||||
|
||||
|
@ -930,21 +932,15 @@ DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
|
|||
isInt, isSigned);
|
||||
}
|
||||
|
||||
/// A method used to verify specific type invariants that the templatized 'get'
|
||||
/// method cannot.
|
||||
bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
|
||||
bool isSigned) const {
|
||||
return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
|
||||
isSigned);
|
||||
return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
|
||||
}
|
||||
|
||||
/// Check the information for a C++ data type, check if this type is valid for
|
||||
/// the current attribute.
|
||||
bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
|
||||
bool isSigned) const {
|
||||
return ::isValidIntOrFloat(
|
||||
getType().getElementType().cast<ComplexType>().getElementType(),
|
||||
dataEltSize / 2, isInt, isSigned);
|
||||
getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
|
||||
isInt, isSigned);
|
||||
}
|
||||
|
||||
/// Returns true if this attribute corresponds to a splat, i.e. if all element
|
||||
|
@ -953,76 +949,69 @@ bool DenseElementsAttr::isSplat() const {
|
|||
return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of Attributes.
|
||||
auto DenseElementsAttr::getAttributeValues() const
|
||||
-> llvm::iterator_range<AttributeElementIterator> {
|
||||
return {attr_value_begin(), attr_value_end()};
|
||||
}
|
||||
auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
|
||||
return AttributeElementIterator(*this, 0);
|
||||
}
|
||||
auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
|
||||
return AttributeElementIterator(*this, getNumElements());
|
||||
/// Return if the given complex type has an integer element type.
|
||||
static bool isComplexOfIntType(Type type) {
|
||||
return type.cast<ComplexType>().getElementType().isa<IntegerType>();
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of bool. The element type of
|
||||
/// this attribute must be of integer type of bitwidth 1.
|
||||
auto DenseElementsAttr::getBoolValues() const
|
||||
-> llvm::iterator_range<BoolElementIterator> {
|
||||
auto eltType = getType().getElementType().dyn_cast<IntegerType>();
|
||||
assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
|
||||
(void)eltType;
|
||||
return {BoolElementIterator(*this, 0),
|
||||
BoolElementIterator(*this, getNumElements())};
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of APInts. The element type of
|
||||
/// this attribute must be of integer type.
|
||||
auto DenseElementsAttr::getIntValues() const
|
||||
-> llvm::iterator_range<IntElementIterator> {
|
||||
assert(getType().getElementType().isIntOrIndex() && "expected integral type");
|
||||
return {raw_int_begin(), raw_int_end()};
|
||||
}
|
||||
auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
|
||||
assert(getType().getElementType().isIntOrIndex() && "expected integral type");
|
||||
return raw_int_begin();
|
||||
}
|
||||
auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
|
||||
assert(getType().getElementType().isIntOrIndex() && "expected integral type");
|
||||
return raw_int_end();
|
||||
}
|
||||
auto DenseElementsAttr::getComplexIntValues() const
|
||||
-> llvm::iterator_range<ComplexIntElementIterator> {
|
||||
Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
|
||||
(void)eltTy;
|
||||
assert(eltTy.isa<IntegerType>() && "expected complex integral type");
|
||||
assert(isComplexOfIntType(getElementType()) &&
|
||||
"expected complex integral type");
|
||||
return {ComplexIntElementIterator(*this, 0),
|
||||
ComplexIntElementIterator(*this, getNumElements())};
|
||||
}
|
||||
auto DenseElementsAttr::complex_value_begin() const
|
||||
-> ComplexIntElementIterator {
|
||||
assert(isComplexOfIntType(getElementType()) &&
|
||||
"expected complex integral type");
|
||||
return ComplexIntElementIterator(*this, 0);
|
||||
}
|
||||
auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
|
||||
assert(isComplexOfIntType(getElementType()) &&
|
||||
"expected complex integral type");
|
||||
return ComplexIntElementIterator(*this, getNumElements());
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of APFloat. The element type of
|
||||
/// this attribute must be of float type.
|
||||
auto DenseElementsAttr::getFloatValues() const
|
||||
-> llvm::iterator_range<FloatElementIterator> {
|
||||
auto elementType = getType().getElementType().cast<FloatType>();
|
||||
auto elementType = getElementType().cast<FloatType>();
|
||||
const auto &elementSemantics = elementType.getFloatSemantics();
|
||||
return {FloatElementIterator(elementSemantics, raw_int_begin()),
|
||||
FloatElementIterator(elementSemantics, raw_int_end())};
|
||||
}
|
||||
auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
|
||||
return getFloatValues().begin();
|
||||
auto elementType = getElementType().cast<FloatType>();
|
||||
return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
|
||||
}
|
||||
auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
|
||||
return getFloatValues().end();
|
||||
auto elementType = getElementType().cast<FloatType>();
|
||||
return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
|
||||
}
|
||||
|
||||
auto DenseElementsAttr::getComplexFloatValues() const
|
||||
-> llvm::iterator_range<ComplexFloatElementIterator> {
|
||||
Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
|
||||
Type eltTy = getElementType().cast<ComplexType>().getElementType();
|
||||
assert(eltTy.isa<FloatType>() && "expected complex float type");
|
||||
const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
|
||||
return {{semantics, {*this, 0}},
|
||||
{semantics, {*this, static_cast<size_t>(getNumElements())}}};
|
||||
}
|
||||
auto DenseElementsAttr::complex_float_value_begin() const
|
||||
-> ComplexFloatElementIterator {
|
||||
Type eltTy = getElementType().cast<ComplexType>().getElementType();
|
||||
assert(eltTy.isa<FloatType>() && "expected complex float type");
|
||||
return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
|
||||
}
|
||||
auto DenseElementsAttr::complex_float_value_end() const
|
||||
-> ComplexFloatElementIterator {
|
||||
Type eltTy = getElementType().cast<ComplexType>().getElementType();
|
||||
assert(eltTy.isa<FloatType>() && "expected complex float type");
|
||||
return {eltTy.cast<FloatType>().getFloatSemantics(),
|
||||
{*this, static_cast<size_t>(getNumElements())}};
|
||||
}
|
||||
|
||||
/// Return the raw storage data held by this attribute.
|
||||
ArrayRef<char> DenseElementsAttr::getRawData() const {
|
||||
|
@ -1374,19 +1363,19 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
|
||||
/// Get a zero APFloat for the given sparse attribute.
|
||||
APFloat SparseElementsAttr::getZeroAPFloat() const {
|
||||
auto eltType = getType().getElementType().cast<FloatType>();
|
||||
auto eltType = getElementType().cast<FloatType>();
|
||||
return APFloat(eltType.getFloatSemantics());
|
||||
}
|
||||
|
||||
/// Get a zero APInt for the given sparse attribute.
|
||||
APInt SparseElementsAttr::getZeroAPInt() const {
|
||||
auto eltType = getType().getElementType().cast<IntegerType>();
|
||||
auto eltType = getElementType().cast<IntegerType>();
|
||||
return APInt::getZero(eltType.getWidth());
|
||||
}
|
||||
|
||||
/// Get a zero attribute for the given attribute type.
|
||||
Attribute SparseElementsAttr::getZeroAttr() const {
|
||||
auto eltType = getType().getElementType();
|
||||
auto eltType = getElementType();
|
||||
|
||||
// Handle floating point elements.
|
||||
if (eltType.isa<FloatType>())
|
||||
|
|
|
@ -1024,7 +1024,7 @@ LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op,
|
|||
return op->emitOpError("requires 1D i32 elements attribute '")
|
||||
<< attrName << "'";
|
||||
|
||||
if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) {
|
||||
if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) {
|
||||
return !element.isNonNegative();
|
||||
}))
|
||||
return op->emitOpError("'")
|
||||
|
|
|
@ -51,7 +51,7 @@ void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
|
|||
auto dattr = attr.cast<DenseIntElementsAttr>();
|
||||
res.clear();
|
||||
res.reserve(dattr.size());
|
||||
for (auto it : dattr.getIntValues())
|
||||
for (auto it : dattr.getValues<APInt>())
|
||||
res.push_back(it.getSExtValue());
|
||||
} else {
|
||||
auto vals = val.get<ShapedTypeComponents *>()->getDims();
|
||||
|
@ -71,7 +71,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
|
|||
return t.cast<ShapedType>().getDimSize(index);
|
||||
if (auto attr = val.dyn_cast<Attribute>())
|
||||
return attr.cast<DenseIntElementsAttr>()
|
||||
.getValue<APInt>({static_cast<uint64_t>(index)})
|
||||
.getFlatValue<APInt>(index)
|
||||
.getSExtValue();
|
||||
auto *stc = val.get<ShapedTypeComponents *>();
|
||||
return stc->getDims()[index];
|
||||
|
@ -94,7 +94,7 @@ bool ShapeAdaptor::hasStaticShape() const {
|
|||
return t.cast<ShapedType>().hasStaticShape();
|
||||
if (auto attr = val.dyn_cast<Attribute>()) {
|
||||
auto dattr = attr.cast<DenseIntElementsAttr>();
|
||||
for (auto index : dattr.getIntValues())
|
||||
for (auto index : dattr.getValues<APInt>())
|
||||
if (ShapedType::isDynamic(index.getSExtValue()))
|
||||
return false;
|
||||
return true;
|
||||
|
@ -115,7 +115,7 @@ int64_t ShapeAdaptor::getNumElements() const {
|
|||
if (auto attr = val.dyn_cast<Attribute>()) {
|
||||
auto dattr = attr.cast<DenseIntElementsAttr>();
|
||||
int64_t num = 1;
|
||||
for (auto index : dattr.getIntValues()) {
|
||||
for (auto index : dattr.getValues<APInt>()) {
|
||||
num *= index.getZExtValue();
|
||||
assert(num >= 0 && "integer overflow in element count computation");
|
||||
}
|
||||
|
|
|
@ -294,7 +294,8 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
|
|||
if (!nested)
|
||||
return nullptr;
|
||||
|
||||
values.append(nested.attr_value_begin(), nested.attr_value_end());
|
||||
values.append(nested.value_begin<Attribute>(),
|
||||
nested.value_end<Attribute>());
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(outerType, values);
|
||||
|
|
|
@ -83,12 +83,14 @@ const char *opSegmentSizeAttrInitCode = R"(
|
|||
auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>();
|
||||
)";
|
||||
const char *attrSizedSegmentValueRangeCalcCode = R"(
|
||||
auto sizeAttrValues = sizeAttr.getValues<uint32_t>();
|
||||
const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
|
||||
if (sizeAttr.isSplat())
|
||||
return {*sizeAttrValueIt * index, *sizeAttrValueIt};
|
||||
|
||||
unsigned start = 0;
|
||||
for (unsigned i = 0; i < index; ++i)
|
||||
start += *(sizeAttrValues.begin() + i);
|
||||
unsigned size = *(sizeAttrValues.begin() + index);
|
||||
return {start, size};
|
||||
start += sizeAttrValueIt[i];
|
||||
return {start, sizeAttrValueIt[index]};
|
||||
)";
|
||||
// The logic to calculate the actual value range for a declared operand
|
||||
// of an op with variadic of variadic operands within the OpAdaptor.
|
||||
|
|
|
@ -158,7 +158,7 @@ TEST(StructsGenTest, GetElements) {
|
|||
auto denseAttr = returnedAttr.dyn_cast<mlir::DenseElementsAttr>();
|
||||
ASSERT_TRUE(denseAttr);
|
||||
|
||||
for (const auto &valIndexIt : llvm::enumerate(denseAttr.getIntValues())) {
|
||||
for (const auto &valIndexIt : llvm::enumerate(denseAttr.getValues<APInt>())) {
|
||||
EXPECT_EQ(valIndexIt.value(), valIndexIt.index() + 1);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue