478 lines
14 KiB
C++
478 lines
14 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
#pragma once
|
|
|
|
#include "cute/util/print.hpp"
|
|
#include "cute/util/type_traits.hpp"
|
|
#include "cute/numeric/math.hpp"
|
|
|
|
namespace cute
|
|
{
|
|
|
|
// A constant value: short name and type-deduction for fast compilation
|
|
template <auto v>
|
|
struct C {
|
|
using type = C<v>;
|
|
static constexpr auto value = v;
|
|
using value_type = decltype(v);
|
|
CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
|
|
CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
|
|
};
|
|
|
|
// Deprecate
|
|
template <class T, T v>
|
|
using constant = C<v>;
|
|
|
|
template <bool b>
|
|
using bool_constant = C<b>;
|
|
|
|
using true_type = bool_constant<true>;
|
|
using false_type = bool_constant<false>;
|
|
|
|
// A more std:: conforming integral_constant that enforces type but interops with C<v>
|
|
template <class T, T v>
|
|
struct integral_constant : C<v> {
|
|
using type = integral_constant<T,v>;
|
|
static constexpr T value = v;
|
|
using value_type = T;
|
|
// Disambiguate C<v>::operator value_type()
|
|
//CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; }
|
|
CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; }
|
|
};
|
|
|
|
//
|
|
// Traits
|
|
//
|
|
|
|
// Use cute::is_std_integral<T> to match built-in integral types (int, int64_t, unsigned, etc)
|
|
// Use cute::is_integral<T> to match both built-in integral types AND static integral types.
|
|
|
|
template <class T>
|
|
struct is_integral : bool_constant<is_std_integral<T>::value> {};
|
|
template <auto v>
|
|
struct is_integral<C<v> > : true_type {};
|
|
template <class T, T v>
|
|
struct is_integral<integral_constant<T,v>> : true_type {};
|
|
|
|
// is_static detects if an (abstract) value is defined completely by it's type (no members)
|
|
|
|
template <class T>
|
|
struct is_static : bool_constant<is_empty<remove_cvref_t<T>>::value> {};
|
|
|
|
template <class T>
|
|
constexpr bool is_static_v = is_static<T>::value;
|
|
|
|
// is_constant detects if a type is a static integral type and if v is equal to a value
|
|
|
|
template <auto n, class T>
|
|
struct is_constant : false_type {};
|
|
template <auto n, class T>
|
|
struct is_constant<n, T const > : is_constant<n,T> {};
|
|
template <auto n, class T>
|
|
struct is_constant<n, T const&> : is_constant<n,T> {};
|
|
template <auto n, class T>
|
|
struct is_constant<n, T &> : is_constant<n,T> {};
|
|
template <auto n, class T>
|
|
struct is_constant<n, T &&> : is_constant<n,T> {};
|
|
template <auto n, auto v>
|
|
struct is_constant<n, C<v> > : bool_constant<v == n> {};
|
|
template <auto n, class T, T v>
|
|
struct is_constant<n, integral_constant<T,v>> : bool_constant<v == n> {};
|
|
|
|
//
|
|
// Specializations
|
|
//
|
|
|
|
template <int v>
|
|
using Int = C<v>;
|
|
|
|
using _m32 = Int<-32>;
|
|
using _m24 = Int<-24>;
|
|
using _m16 = Int<-16>;
|
|
using _m12 = Int<-12>;
|
|
using _m10 = Int<-10>;
|
|
using _m9 = Int<-9>;
|
|
using _m8 = Int<-8>;
|
|
using _m7 = Int<-7>;
|
|
using _m6 = Int<-6>;
|
|
using _m5 = Int<-5>;
|
|
using _m4 = Int<-4>;
|
|
using _m3 = Int<-3>;
|
|
using _m2 = Int<-2>;
|
|
using _m1 = Int<-1>;
|
|
using _0 = Int<0>;
|
|
using _1 = Int<1>;
|
|
using _2 = Int<2>;
|
|
using _3 = Int<3>;
|
|
using _4 = Int<4>;
|
|
using _5 = Int<5>;
|
|
using _6 = Int<6>;
|
|
using _7 = Int<7>;
|
|
using _8 = Int<8>;
|
|
using _9 = Int<9>;
|
|
using _10 = Int<10>;
|
|
using _12 = Int<12>;
|
|
using _16 = Int<16>;
|
|
using _24 = Int<24>;
|
|
using _32 = Int<32>;
|
|
using _64 = Int<64>;
|
|
using _96 = Int<96>;
|
|
using _128 = Int<128>;
|
|
using _192 = Int<192>;
|
|
using _256 = Int<256>;
|
|
using _384 = Int<384>;
|
|
using _512 = Int<512>;
|
|
using _768 = Int<768>;
|
|
using _1024 = Int<1024>;
|
|
using _2048 = Int<2048>;
|
|
using _4096 = Int<4096>;
|
|
using _8192 = Int<8192>;
|
|
using _16384 = Int<16384>;
|
|
using _32768 = Int<32768>;
|
|
using _65536 = Int<65536>;
|
|
using _131072 = Int<131072>;
|
|
using _262144 = Int<262144>;
|
|
using _524288 = Int<524288>;
|
|
|
|
/***************/
|
|
/** Operators **/
|
|
/***************/
|
|
|
|
#define CUTE_LEFT_UNARY_OP(OP) \
|
|
template <auto t> \
|
|
CUTE_HOST_DEVICE constexpr \
|
|
C<(OP t)> operator OP (C<t>) { \
|
|
return {}; \
|
|
}
|
|
#define CUTE_RIGHT_UNARY_OP(OP) \
|
|
template <auto t> \
|
|
CUTE_HOST_DEVICE constexpr \
|
|
C<(t OP)> operator OP (C<t>) { \
|
|
return {}; \
|
|
}
|
|
#define CUTE_BINARY_OP(OP) \
|
|
template <auto t, auto u> \
|
|
CUTE_HOST_DEVICE constexpr \
|
|
C<(t OP u)> operator OP (C<t>, C<u>) { \
|
|
return {}; \
|
|
}
|
|
|
|
CUTE_LEFT_UNARY_OP(+);
|
|
CUTE_LEFT_UNARY_OP(-);
|
|
CUTE_LEFT_UNARY_OP(~);
|
|
CUTE_LEFT_UNARY_OP(!);
|
|
CUTE_LEFT_UNARY_OP(*);
|
|
|
|
CUTE_BINARY_OP( +);
|
|
CUTE_BINARY_OP( -);
|
|
CUTE_BINARY_OP( *);
|
|
CUTE_BINARY_OP( /);
|
|
CUTE_BINARY_OP( %);
|
|
CUTE_BINARY_OP( &);
|
|
CUTE_BINARY_OP( |);
|
|
CUTE_BINARY_OP( ^);
|
|
CUTE_BINARY_OP(<<);
|
|
CUTE_BINARY_OP(>>);
|
|
|
|
CUTE_BINARY_OP(&&);
|
|
CUTE_BINARY_OP(||);
|
|
|
|
CUTE_BINARY_OP(==);
|
|
CUTE_BINARY_OP(!=);
|
|
CUTE_BINARY_OP( >);
|
|
CUTE_BINARY_OP( <);
|
|
CUTE_BINARY_OP(>=);
|
|
CUTE_BINARY_OP(<=);
|
|
|
|
#undef CUTE_BINARY_OP
|
|
#undef CUTE_LEFT_UNARY_OP
|
|
#undef CUTE_RIGHT_UNARY_OP
|
|
|
|
//
|
|
// Mixed static-dynamic special cases
|
|
//
|
|
|
|
template <auto t, class U,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && t == 0)>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<0>
|
|
operator*(C<t>, U) {
|
|
return {};
|
|
}
|
|
|
|
template <class U, auto t,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && t == 0)>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<0>
|
|
operator*(U, C<t>) {
|
|
return {};
|
|
}
|
|
|
|
template <auto t, class U,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && t == 0)>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<0>
|
|
operator/(C<t>, U) {
|
|
return {};
|
|
}
|
|
|
|
template <class U, auto t,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && (t == 1 || t == -1))>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<0>
|
|
operator%(U, C<t>) {
|
|
return {};
|
|
}
|
|
|
|
template <auto t, class U,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && t == 0)>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<0>
|
|
operator%(C<t>, U) {
|
|
return {};
|
|
}
|
|
|
|
template <auto t, class U,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && t == 0)>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<0>
|
|
operator&(C<t>, U) {
|
|
return {};
|
|
}
|
|
|
|
template <class U, auto t,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && t == 0)>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<0>
|
|
operator&(U, C<t>) {
|
|
return {};
|
|
}
|
|
|
|
template <auto t, class U,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && !bool(t))>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<false>
|
|
operator&&(C<t>, U) {
|
|
return {};
|
|
}
|
|
|
|
template <auto t, class U,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && !bool(t))>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<false>
|
|
operator&&(U, C<t>) {
|
|
return {};
|
|
}
|
|
|
|
template <class U, auto t,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && bool(t))>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<true>
|
|
operator||(C<t>, U) {
|
|
return {};
|
|
}
|
|
|
|
template <class U, auto t,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value && bool(t))>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<true>
|
|
operator||(U, C<t>) {
|
|
return {};
|
|
}
|
|
|
|
//
|
|
// Named functions from math.hpp
|
|
//
|
|
|
|
#define CUTE_NAMED_UNARY_FN(OP) \
|
|
template <auto t> \
|
|
CUTE_HOST_DEVICE constexpr \
|
|
C<OP(t)> OP (C<t>) { \
|
|
return {}; \
|
|
}
|
|
#define CUTE_NAMED_BINARY_FN(OP) \
|
|
template <auto t, auto u> \
|
|
CUTE_HOST_DEVICE constexpr \
|
|
C<OP(t,u)> OP (C<t>, C<u>) { \
|
|
return {}; \
|
|
} \
|
|
template <auto t, class U, \
|
|
__CUTE_REQUIRES(is_std_integral<U>::value)> \
|
|
CUTE_HOST_DEVICE constexpr \
|
|
auto OP (C<t>, U u) { \
|
|
return OP(t,u); \
|
|
} \
|
|
template <class T, auto u, \
|
|
__CUTE_REQUIRES(is_std_integral<T>::value)> \
|
|
CUTE_HOST_DEVICE constexpr \
|
|
auto OP (T t, C<u>) { \
|
|
return OP(t,u); \
|
|
}
|
|
|
|
CUTE_NAMED_UNARY_FN(abs);
|
|
CUTE_NAMED_UNARY_FN(signum);
|
|
CUTE_NAMED_UNARY_FN(has_single_bit);
|
|
|
|
CUTE_NAMED_BINARY_FN(max);
|
|
CUTE_NAMED_BINARY_FN(min);
|
|
CUTE_NAMED_BINARY_FN(shiftl);
|
|
CUTE_NAMED_BINARY_FN(shiftr);
|
|
CUTE_NAMED_BINARY_FN(gcd);
|
|
CUTE_NAMED_BINARY_FN(lcm);
|
|
|
|
#undef CUTE_NAMED_UNARY_FN
|
|
#undef CUTE_NAMED_BINARY_FN
|
|
|
|
//
|
|
// Other functions
|
|
//
|
|
|
|
template <auto t, auto u>
|
|
CUTE_HOST_DEVICE constexpr
|
|
C<t / u>
|
|
safe_div(C<t>, C<u>) {
|
|
static_assert(t % u == 0, "Static safe_div requires t % u == 0");
|
|
return {};
|
|
}
|
|
|
|
template <auto t, class U,
|
|
__CUTE_REQUIRES(is_std_integral<U>::value)>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
safe_div(C<t>, U u) {
|
|
return t / u;
|
|
}
|
|
|
|
template <class T, auto u,
|
|
__CUTE_REQUIRES(is_std_integral<T>::value)>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
safe_div(T t, C<u>) {
|
|
return t / u;
|
|
}
|
|
|
|
template <class TrueType, class FalseType>
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
conditional_return(true_type, TrueType&& t, FalseType&&) {
|
|
return static_cast<TrueType&&>(t);
|
|
}
|
|
|
|
template <class TrueType, class FalseType>
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
conditional_return(false_type, TrueType&&, FalseType&& f) {
|
|
return static_cast<FalseType&&>(f);
|
|
}
|
|
|
|
// TrueType and FalseType must have a common type
|
|
template <class TrueType, class FalseType>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
conditional_return(bool b, TrueType const& t, FalseType const& f) {
|
|
return b ? t : f;
|
|
}
|
|
|
|
// TrueType and FalseType don't require a common type
|
|
template <bool b, class TrueType, class FalseType>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
conditional_return(TrueType const& t, FalseType const& f) {
|
|
if constexpr (b) {
|
|
return t;
|
|
} else {
|
|
return f;
|
|
}
|
|
}
|
|
|
|
template <class Trait>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
static_value()
|
|
{
|
|
if constexpr (is_std_integral<decltype(Trait::value)>::value) {
|
|
return Int<Trait::value>{};
|
|
} else {
|
|
return Trait::value;
|
|
}
|
|
CUTE_GCC_UNREACHABLE;
|
|
}
|
|
|
|
//
|
|
// Display utilities
|
|
//
|
|
|
|
template <auto Value>
|
|
CUTE_HOST_DEVICE void print(C<Value>) {
|
|
printf("_");
|
|
::cute::print(Value);
|
|
}
|
|
|
|
#if !defined(__CUDACC_RTC__)
|
|
template <auto t>
|
|
CUTE_HOST std::ostream& operator<<(std::ostream& os, C<t> const&) {
|
|
return os << "_" << t;
|
|
}
|
|
#endif
|
|
|
|
|
|
namespace detail {
|
|
|
|
// parse_int_digits takes a variadic number of digits and converts them into an int
|
|
template <class... Ts>
|
|
constexpr uint64_t parse_int_digits(uint64_t result, int digit, Ts... digits)
|
|
{
|
|
if constexpr (sizeof...(Ts) == 0) {
|
|
return 10 * result + digit;
|
|
} else {
|
|
return parse_int_digits(10 * result + digit, digits...);
|
|
}
|
|
}
|
|
|
|
} // end namespace detail
|
|
|
|
|
|
// This user-defined literal operator allows cute::constant written as literals. For example,
|
|
//
|
|
// auto var = 32_c;
|
|
//
|
|
// var has type cute::constant<int,32>.
|
|
//
|
|
template <char... digits>
|
|
constexpr cute::constant<int,detail::parse_int_digits(0, (digits - '0')...)> operator "" _c()
|
|
{
|
|
static_assert((('0' <= digits && digits <= '9') && ...),
|
|
"Expected 0 <= digit <= 9 for each digit of the integer.");
|
|
return {};
|
|
}
|
|
|
|
} // end namespace cute
|