557 lines
15 KiB
C++
557 lines
15 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
#pragma once
|
|
|
|
#include <cute/config.hpp>
|
|
|
|
#include <cute/container/tuple.hpp>
|
|
#include <cute/numeric/integral_constant.hpp>
|
|
#include <cute/algorithm/functional.hpp>
|
|
#include <cute/algorithm/tuple_algorithms.hpp>
|
|
#include <cute/util/type_traits.hpp>
|
|
|
|
namespace cute
|
|
{
|
|
|
|
template <class... T>
|
|
struct ArithmeticTuple : tuple<T...>
|
|
{
|
|
template <class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
ArithmeticTuple(ArithmeticTuple<U...> const& u)
|
|
: tuple<T...>(static_cast<tuple<U...> const&>(u)) {}
|
|
|
|
template <class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
ArithmeticTuple(tuple<U...> const& u)
|
|
: tuple<T...>(u) {}
|
|
|
|
template <class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
ArithmeticTuple(U const&... u)
|
|
: tuple<T...>(u...) {}
|
|
};
|
|
|
|
template <class... T>
|
|
struct is_tuple<ArithmeticTuple<T...>> : true_type {};
|
|
|
|
template <class... Ts>
|
|
struct is_flat<ArithmeticTuple<Ts...>> : is_flat<tuple<Ts...>> {};
|
|
|
|
template <class... T>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
make_arithmetic_tuple(T const&... t) {
|
|
return ArithmeticTuple<T...>(t...);
|
|
}
|
|
|
|
template <class T>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
as_arithmetic_tuple(T const& t) {
|
|
if constexpr (is_tuple<T>::value) {
|
|
return detail::tapply(t, [](auto const& x){ return as_arithmetic_tuple(x); },
|
|
[](auto const&... a){ return make_arithmetic_tuple(a...); },
|
|
tuple_seq<T>{});
|
|
} else {
|
|
return t;
|
|
}
|
|
}
|
|
|
|
//
|
|
// Numeric operators
|
|
//
|
|
|
|
// Addition
|
|
template <class... T, class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator+(ArithmeticTuple<T...> const& t, ArithmeticTuple<U...> const& u) {
|
|
constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
|
|
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
|
|
}
|
|
|
|
template <class... T, class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator+(ArithmeticTuple<T...> const& t, tuple<U...> const& u) {
|
|
return t + ArithmeticTuple<U...>(u);
|
|
}
|
|
|
|
template <class... T, class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator+(tuple<T...> const& t, ArithmeticTuple<U...> const& u) {
|
|
return ArithmeticTuple<T...>(t) + u;
|
|
}
|
|
|
|
// Subtraction
|
|
template <class... T, class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator-(ArithmeticTuple<T...> const& t, ArithmeticTuple<U...> const& u) {
|
|
constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U)));
|
|
return transform_apply(append<R>(t,Int<0>{}), append<R>(u,Int<0>{}), minus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
|
|
}
|
|
|
|
template <class... T, class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator-(ArithmeticTuple<T...> const& t, tuple<U...> const& u) {
|
|
return t - ArithmeticTuple<U...>(u);
|
|
}
|
|
|
|
template <class... T, class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator-(tuple<T...> const& t, ArithmeticTuple<U...> const& u) {
|
|
return ArithmeticTuple<T...>(t) - u;
|
|
}
|
|
|
|
// Negation
|
|
template <class... T>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator-(ArithmeticTuple<T...> const& t) {
|
|
return transform_apply(t, negate{}, [](auto const&... a){ return make_arithmetic_tuple(a...); });
|
|
}
|
|
|
|
//
|
|
// Special cases
|
|
//
|
|
|
|
template <auto t, class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
ArithmeticTuple<U...> const&
|
|
operator+(C<t>, ArithmeticTuple<U...> const& u) {
|
|
static_assert(t == 0, "Arithmetic tuple op+ error!");
|
|
return u;
|
|
}
|
|
|
|
template <class... T, auto u>
|
|
CUTE_HOST_DEVICE constexpr
|
|
ArithmeticTuple<T...> const&
|
|
operator+(ArithmeticTuple<T...> const& t, C<u>) {
|
|
static_assert(u == 0, "Arithmetic tuple op+ error!");
|
|
return t;
|
|
}
|
|
|
|
template <auto t, class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
ArithmeticTuple<U...> const&
|
|
operator-(C<t>, ArithmeticTuple<U...> const& u) {
|
|
static_assert(t == 0, "Arithmetic tuple op- error!");
|
|
return -u;
|
|
}
|
|
|
|
template <class... T, auto u>
|
|
CUTE_HOST_DEVICE constexpr
|
|
ArithmeticTuple<T...> const&
|
|
operator-(ArithmeticTuple<T...> const& t, C<u>) {
|
|
static_assert(u == 0, "Arithmetic tuple op- error!");
|
|
return t;
|
|
}
|
|
|
|
//
|
|
// ArithmeticTupleIterator
|
|
//
|
|
|
|
template <class ArithTuple>
|
|
struct ArithmeticTupleIterator
|
|
{
|
|
using value_type = ArithTuple;
|
|
using element_type = ArithTuple;
|
|
using reference = ArithTuple;
|
|
|
|
ArithTuple coord_;
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
ArithmeticTupleIterator(ArithTuple const& coord = {}) : coord_(coord) {}
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
ArithTuple const& operator*() const { return coord_; }
|
|
|
|
template <class Coord>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto operator[](Coord const& c) const { return *(*this + c); }
|
|
|
|
template <class Coord>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto operator+(Coord const& c) const {
|
|
return ArithmeticTupleIterator<decltype(coord_ + c)>(coord_ + c);
|
|
}
|
|
};
|
|
|
|
template <class Tuple>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
make_inttuple_iter(Tuple const& t) {
|
|
return ArithmeticTupleIterator(as_arithmetic_tuple(t));
|
|
}
|
|
|
|
template <class T0, class T1, class... Ts>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) {
|
|
return make_inttuple_iter(cute::make_tuple(t0, t1, ts...));
|
|
}
|
|
|
|
//
|
|
// ArithmeticTuple "basis" elements
|
|
// A ScaledBasis<T,N> is a (at least) rank-N+1 ArithmeticTuple:
|
|
// (_0,_0,...,T,_0,...)
|
|
// with value T in the Nth mode
|
|
|
|
template <class T, int N>
|
|
struct ScaledBasis : private tuple<T>
|
|
{
|
|
CUTE_HOST_DEVICE constexpr
|
|
ScaledBasis(T const& t = {}) : tuple<T>(t) {}
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto) value() { return get<0>(static_cast<tuple<T> &>(*this)); }
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto) value() const { return get<0>(static_cast<tuple<T> const&>(*this)); }
|
|
|
|
CUTE_HOST_DEVICE static constexpr
|
|
auto mode() { return Int<N>{}; }
|
|
};
|
|
|
|
template <class T>
|
|
struct is_scaled_basis : false_type {};
|
|
template <class T, int N>
|
|
struct is_scaled_basis<ScaledBasis<T,N>> : true_type {};
|
|
|
|
template <class T, int N>
|
|
struct is_integral<ScaledBasis<T,N>> : true_type {};
|
|
|
|
// Get the scalar T out of a ScaledBasis
|
|
template <class SB>
|
|
CUTE_HOST_DEVICE constexpr auto
|
|
basis_value(SB const& e)
|
|
{
|
|
if constexpr (is_scaled_basis<SB>::value) {
|
|
return basis_value(e.value());
|
|
} else {
|
|
return e;
|
|
}
|
|
CUTE_GCC_UNREACHABLE;
|
|
}
|
|
|
|
// Apply the N... pack to another Tuple
|
|
template <class SB, class Tuple>
|
|
CUTE_HOST_DEVICE constexpr auto
|
|
basis_get(SB const& e, Tuple const& t)
|
|
{
|
|
if constexpr (is_scaled_basis<SB>::value) {
|
|
return basis_get(e.value(), get<SB::mode()>(t));
|
|
} else {
|
|
return t;
|
|
}
|
|
CUTE_GCC_UNREACHABLE;
|
|
}
|
|
|
|
namespace detail {
|
|
|
|
template <class T, int... I>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
to_atuple_i(T const& t, seq<I...>) {
|
|
return make_arithmetic_tuple((void(I),Int<0>{})..., t);
|
|
}
|
|
|
|
} // end namespace detail
|
|
|
|
// Turn a ScaledBasis<T,N> into a rank-N+1 ArithmeticTuple
|
|
// with N prefix 0s: (_0,_0,...N...,_0,T)
|
|
template <class T, int N>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
as_arithmetic_tuple(ScaledBasis<T,N> const& t) {
|
|
return detail::to_atuple_i(as_arithmetic_tuple(t.value()), make_seq<N>{});
|
|
}
|
|
|
|
namespace detail {
|
|
|
|
template <int... Ns>
|
|
struct Basis;
|
|
|
|
template <>
|
|
struct Basis<> {
|
|
using type = Int<1>;
|
|
};
|
|
|
|
template <int N, int... Ns>
|
|
struct Basis<N,Ns...> {
|
|
using type = ScaledBasis<typename Basis<Ns...>::type, N>;
|
|
};
|
|
|
|
} // end namespace detail
|
|
|
|
// Shortcut for writing ScaledBasis<ScaledBasis<ScaledBasis<Int<1>, N0>, N1>, ...>
|
|
// E<> := _1
|
|
// E<0> := (_1,_0,_0,...)
|
|
// E<1> := (_0,_1,_0,...)
|
|
// E<0,0> := ((_1,_0,_0,...),_0,_0,...)
|
|
// E<0,1> := ((_0,_1,_0,...),_0,_0,...)
|
|
// E<1,0> := (_0,(_1,_0,_0,...),_0,...)
|
|
// E<1,1> := (_0,(_0,_1,_0,...),_0,...)
|
|
template <int... N>
|
|
using E = typename detail::Basis<N...>::type;
|
|
|
|
template <class Shape>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
make_basis_like(Shape const& shape)
|
|
{
|
|
if constexpr (is_integral<Shape>::value) {
|
|
return Int<1>{};
|
|
} else {
|
|
// Generate bases for each rank of shape
|
|
return transform(tuple_seq<Shape>{}, shape, [](auto I, auto si) {
|
|
// Generate bases for each rank of si and add an i on front
|
|
using I_type = decltype(I);
|
|
return transform_leaf(make_basis_like(si), [](auto e) {
|
|
// MSVC has trouble capturing variables as constexpr,
|
|
// so that they can be used as template arguments.
|
|
// This is exactly what the code needs to do with i, unfortunately.
|
|
// The work-around is to define i inside the inner lambda,
|
|
// by using just the type from the enclosing scope.
|
|
constexpr int i = I_type::value;
|
|
return ScaledBasis<decltype(e), i>{};
|
|
});
|
|
});
|
|
}
|
|
|
|
CUTE_GCC_UNREACHABLE;
|
|
}
|
|
|
|
//
|
|
// Arithmetic
|
|
//
|
|
|
|
template <class T, int M, class U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
safe_div(ScaledBasis<T,M> const& b, U const& u)
|
|
{
|
|
auto t = safe_div(b.value(), u);
|
|
return ScaledBasis<decltype(t),M>{t};
|
|
}
|
|
|
|
template <class T, int M, class U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
shape_div(ScaledBasis<T,M> const& b, U const& u)
|
|
{
|
|
auto t = shape_div(b.value(), u);
|
|
return ScaledBasis<decltype(t),M>{t};
|
|
}
|
|
|
|
// Equality
|
|
template <class T, int N, class U, int M>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator==(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
|
|
return bool_constant<M == N>{} && t.value() == u.value();
|
|
}
|
|
|
|
// Not equal to anything else
|
|
template <class T, int N, class U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
false_type
|
|
operator==(ScaledBasis<T,N> const&, U const&) {
|
|
return {};
|
|
}
|
|
|
|
template <class T, class U, int M>
|
|
CUTE_HOST_DEVICE constexpr
|
|
false_type
|
|
operator==(T const&, ScaledBasis<U,M> const&) {
|
|
return {};
|
|
}
|
|
|
|
// Abs
|
|
template <class T, int N>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
abs(ScaledBasis<T,N> const& e) {
|
|
return ScaledBasis<decltype(abs(e.value())),N>{abs(e.value())};
|
|
}
|
|
|
|
// Multiplication
|
|
template <class A, class T, int N>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator*(A const& a, ScaledBasis<T,N> const& e) {
|
|
auto r = a * e.value();
|
|
return ScaledBasis<decltype(r),N>{r};
|
|
}
|
|
|
|
template <class T, int N, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator*(ScaledBasis<T,N> const& e, B const& b) {
|
|
auto r = e.value() * b;
|
|
return ScaledBasis<decltype(r),N>{r};
|
|
}
|
|
|
|
// Addition
|
|
template <class T, int N, class U, int M>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator+(ScaledBasis<T,N> const& t, ScaledBasis<U,M> const& u) {
|
|
return as_arithmetic_tuple(t) + as_arithmetic_tuple(u);
|
|
}
|
|
|
|
template <class T, int N, class... U>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator+(ScaledBasis<T,N> const& t, ArithmeticTuple<U...> const& u) {
|
|
return as_arithmetic_tuple(t) + u;
|
|
}
|
|
|
|
template <class... T, class U, int M>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator+(ArithmeticTuple<T...> const& t, ScaledBasis<U,M> const& u) {
|
|
return t + as_arithmetic_tuple(u);
|
|
}
|
|
|
|
template <auto t, class U, int M>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator+(C<t>, ScaledBasis<U,M> const& u) {
|
|
static_assert(t == 0, "ScaledBasis op+ error!");
|
|
return u;
|
|
}
|
|
|
|
template <class T, int N, auto u>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator+(ScaledBasis<T,N> const& t, C<u>) {
|
|
static_assert(u == 0, "ScaledBasis op+ error!");
|
|
return t;
|
|
}
|
|
|
|
//
|
|
// Display utilities
|
|
//
|
|
|
|
template <class ArithTuple>
|
|
CUTE_HOST_DEVICE void print(ArithmeticTupleIterator<ArithTuple> const& iter)
|
|
{
|
|
printf("ArithTuple"); print(iter.coord_);
|
|
}
|
|
|
|
template <class T, int N>
|
|
CUTE_HOST_DEVICE void print(ScaledBasis<T,N> const& e)
|
|
{
|
|
print(e.value()); printf("@%d", N);
|
|
}
|
|
|
|
#if !defined(__CUDACC_RTC__)
|
|
template <class ArithTuple>
|
|
CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator<ArithTuple> const& iter)
|
|
{
|
|
return os << "ArithTuple" << iter.coord_;
|
|
}
|
|
|
|
template <class T, int N>
|
|
CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis<T,N> const& e)
|
|
{
|
|
return os << e.value() << "@" << N;
|
|
}
|
|
#endif
|
|
|
|
} // end namespace cute
|
|
|
|
|
|
namespace CUTE_STL_NAMESPACE
|
|
{
|
|
|
|
template <class... T>
|
|
struct tuple_size<cute::ArithmeticTuple<T...>>
|
|
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
|
{};
|
|
|
|
template <size_t I, class... T>
|
|
struct tuple_element<I, cute::ArithmeticTuple<T...>>
|
|
: CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
|
|
{};
|
|
|
|
template <class... T>
|
|
struct tuple_size<const cute::ArithmeticTuple<T...>>
|
|
: CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
|
|
{};
|
|
|
|
template <size_t I, class... T>
|
|
struct tuple_element<I, const cute::ArithmeticTuple<T...>>
|
|
: CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<T...>>
|
|
{};
|
|
|
|
} // end namespace CUTE_STL_NAMESPACE
|
|
|
|
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
namespace std
{

// When CUTE_STL_NAMESPACE_IS_CUDA_STD is defined, the same tuple_size /
// tuple_element specializations are also provided in namespace std.

#if defined(__CUDACC_RTC__)
// Under __CUDACC_RTC__ the primary templates are declared here before being specialized.
template <class... _Tp>
struct tuple_size;

template <size_t _Ip, class... _Tp>
struct tuple_element;
#endif

// Non-const ArithmeticTuple
template <class... Elems>
struct tuple_size<cute::ArithmeticTuple<Elems...>>
  : CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(Elems)>
{};

template <size_t I, class... Elems>
struct tuple_element<I, cute::ArithmeticTuple<Elems...>>
  : CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<Elems...>>
{};

// Const ArithmeticTuple
template <class... Elems>
struct tuple_size<const cute::ArithmeticTuple<Elems...>>
  : CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(Elems)>
{};

template <size_t I, class... Elems>
struct tuple_element<I, const cute::ArithmeticTuple<Elems...>>
  : CUTE_STL_NAMESPACE::tuple_element<I, const CUTE_STL_NAMESPACE::tuple<Elems...>>
{};

} // end namespace std
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
|