1403 lines
52 KiB
C++
1403 lines
52 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
|
|
#pragma once
|
|
|
|
#include <cute/config.hpp>
|
|
|
|
#include <cute/arch/mma.hpp>
|
|
|
|
// Config
// The SM90 MMA wrappers in this header are only usable when compiling device
// code for compute capability 9.0 or newer. Host compilation passes (where
// __CUDA_ARCH__ is undefined) and older device architectures leave both
// macros undefined.
#if defined(__CUDA_ARCH__)
#  if __CUDA_ARCH__ >= 900
#    define CUTE_ARCH_MMA_SM90_ENABLED
#    define CUTE_ARCH_MMA_F64_SM90_ENABLED
#  endif
#endif
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cute {
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// MMA 16x8x4 TN
|
|
struct SM90_16x8x4_F64F64F64F64_TN
|
|
{
|
|
using DRegisters = double[4];
|
|
using ARegisters = double[2];
|
|
using BRegisters = double[1];
|
|
using CRegisters = double[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(double & d0, double & d1, double & d2, double & d3,
|
|
double const& a0, double const& a1,
|
|
double const& b0,
|
|
double const& c0, double const& c1, double const& c2, double const& c3)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED)
|
|
asm volatile(
|
|
"mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64"
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5},"
|
|
"{%6},"
|
|
"{%7, %8, %9, %10};\n"
|
|
: "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
|
|
: "d"(a0), "d"(a1),
|
|
"d"(b0),
|
|
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// MMA 16x8x8 TN
|
|
struct SM90_16x8x8_F64F64F64F64_TN
|
|
{
|
|
using DRegisters = double[4];
|
|
using ARegisters = double[4];
|
|
using BRegisters = double[2];
|
|
using CRegisters = double[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(double & d0, double & d1, double & d2, double & d3,
|
|
double const& a0, double const& a1, double const& a2, double const& a3,
|
|
double const& b0, double const& b1,
|
|
double const& c0, double const& c1, double const& c2, double const& c3)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED)
|
|
asm volatile(
|
|
"mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64"
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
"{%8, %9},"
|
|
"{%10, %11, %12, %13};\n"
|
|
: "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
|
|
: "d"(a0), "d"(a1), "d"(a2), "d"(a3),
|
|
"d"(b0), "d"(b1),
|
|
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// MMA 16x8x16 TN
|
|
struct SM90_16x8x16_F64F64F64F64_TN
|
|
{
|
|
using DRegisters = double[4];
|
|
using ARegisters = double[8];
|
|
using BRegisters = double[4];
|
|
using CRegisters = double[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(double & d0, double & d1, double & d2, double & d3,
|
|
double const& a0, double const& a1, double const& a2, double const& a3,
|
|
double const& a4, double const& a5, double const& a6, double const& a7,
|
|
double const& b0, double const& b1, double const& b2, double const& b3,
|
|
double const& c0, double const& c1, double const& c2, double const& c3)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED)
|
|
asm volatile(
|
|
"mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64"
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7, %8, %9, %10, %11},"
|
|
"{%12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19};\n"
|
|
: "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3)
|
|
: "d"(a0), "d"(a1), "d"(a2), "d"(a3),
|
|
"d"(a4), "d"(a5), "d"(a6), "d"(a7),
|
|
"d"(b0), "d"(b1), "d"(b2), "d"(b3),
|
|
"d"(c0), "d"(c1), "d"(c2), "d"(c3));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// MMA 16x8x4 TN
|
|
struct SM90_16x8x4_C64C64C64C64_TN
|
|
{
|
|
using DRegisters = complex<double>[4];
|
|
using ARegisters = complex<double>[2];
|
|
using BRegisters = complex<double>[1];
|
|
using CRegisters = complex<double>[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(complex<double> & d0, complex<double> & d1,
|
|
complex<double> & d2, complex<double> & d3,
|
|
complex<double> const& a0, complex<double> const& a1,
|
|
complex<double> const& b0,
|
|
complex<double> const& c0, complex<double> const& c1,
|
|
complex<double> const& c2, complex<double> const& c3)
|
|
{
|
|
// Because thrust::complex does not provide a mutable ref
|
|
double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
|
|
double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
|
|
double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
|
|
double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
|
|
double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
|
|
double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
|
|
double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
|
|
double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];
|
|
|
|
// d.real() = a.real() * b.real() + c.real();
|
|
SM90_16x8x4_F64F64F64F64_TN::fma(
|
|
rd0, rd1, rd2, rd3,
|
|
a0.real(), a1.real(),
|
|
b0.real(),
|
|
c0.real(), c1.real(), c2.real(), c3.real());
|
|
|
|
// d.imag() = a.imag() * b.real() + c.imag();
|
|
SM90_16x8x4_F64F64F64F64_TN::fma(
|
|
id0, id1, id2, id3,
|
|
a0.imag(), a1.imag(),
|
|
b0.real(),
|
|
c0.imag(), c1.imag(), c2.imag(), c3.imag());
|
|
|
|
// d.real() = -a.imag() * b.imag() + d.real();
|
|
SM90_16x8x4_F64F64F64F64_TN::fma(
|
|
rd0, rd1, rd2, rd3,
|
|
-a0.imag(), -a1.imag(),
|
|
b0.imag(),
|
|
d0.real(), d1.real(), d2.real(), d3.real());
|
|
|
|
// d.imag() = a.real() * b.imag() + d.imag();
|
|
SM90_16x8x4_F64F64F64F64_TN::fma(
|
|
id0, id1, id2, id3,
|
|
a0.real(), a1.real(),
|
|
b0.imag(),
|
|
d0.imag(), d1.imag(), d2.imag(), d3.imag());
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// MMA 16x8x8 TN
|
|
struct SM90_16x8x8_C64C64C64C64_TN
|
|
{
|
|
using DRegisters = complex<double>[4];
|
|
using ARegisters = complex<double>[4];
|
|
using BRegisters = complex<double>[2];
|
|
using CRegisters = complex<double>[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(complex<double> & d0, complex<double> & d1,
|
|
complex<double> & d2, complex<double> & d3,
|
|
complex<double> const& a0, complex<double> const& a1,
|
|
complex<double> const& a2, complex<double> const& a3,
|
|
complex<double> const& b0, complex<double> const& b1,
|
|
complex<double> const& c0, complex<double> const& c1,
|
|
complex<double> const& c2, complex<double> const& c3)
|
|
{
|
|
// Because thrust::complex does not provide a mutable ref
|
|
double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
|
|
double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
|
|
double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
|
|
double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
|
|
double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
|
|
double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
|
|
double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
|
|
double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];
|
|
|
|
// d.real() = a.real() * b.real() + c.real();
|
|
SM90_16x8x8_F64F64F64F64_TN::fma(
|
|
rd0, rd1, rd2, rd3,
|
|
a0.real(), a1.real(), a2.real(), a3.real(),
|
|
b0.real(), b1.real(),
|
|
c0.real(), c1.real(), c2.real(), c3.real());
|
|
|
|
// d.imag() = a.imag() * b.real() + c.imag();
|
|
SM90_16x8x8_F64F64F64F64_TN::fma(
|
|
id0, id1, id2, id3,
|
|
a0.imag(), a1.imag(), a2.imag(), a3.imag(),
|
|
b0.real(), b1.real(),
|
|
c0.imag(), c1.imag(), c2.imag(), c3.imag());
|
|
|
|
// d.real() = -a.imag() * b.imag() + d.real();
|
|
SM90_16x8x8_F64F64F64F64_TN::fma(
|
|
rd0, rd1, rd2, rd3,
|
|
-a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(),
|
|
b0.imag(), b1.imag(),
|
|
d0.real(), d1.real(), d2.real(), d3.real());
|
|
|
|
// d.imag() = a.real() * b.imag() + d.imag();
|
|
SM90_16x8x8_F64F64F64F64_TN::fma(
|
|
id0, id1, id2, id3,
|
|
a0.real(), a1.real(), a2.real(), a3.real(),
|
|
b0.imag(), b1.imag(),
|
|
d0.imag(), d1.imag(), d2.imag(), d3.imag());
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// MMA 16x8x16 TN
|
|
struct SM90_16x8x16_C64C64C64C64_TN
|
|
{
|
|
using DRegisters = complex<double>[4];
|
|
using ARegisters = complex<double>[8];
|
|
using BRegisters = complex<double>[4];
|
|
using CRegisters = complex<double>[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(complex<double> & d0, complex<double> & d1,
|
|
complex<double> & d2, complex<double> & d3,
|
|
complex<double> const& a0, complex<double> const& a1,
|
|
complex<double> const& a2, complex<double> const& a3,
|
|
complex<double> const& a4, complex<double> const& a5,
|
|
complex<double> const& a6, complex<double> const& a7,
|
|
complex<double> const& b0, complex<double> const& b1,
|
|
complex<double> const& b2, complex<double> const& b3,
|
|
complex<double> const& c0, complex<double> const& c1,
|
|
complex<double> const& c2, complex<double> const& c3)
|
|
{
|
|
// Because thrust::complex does not provide a mutable ref
|
|
double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0];
|
|
double& id0 = reinterpret_cast<double(&)[2]>(d0)[1];
|
|
double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0];
|
|
double& id1 = reinterpret_cast<double(&)[2]>(d1)[1];
|
|
double& rd2 = reinterpret_cast<double(&)[2]>(d2)[0];
|
|
double& id2 = reinterpret_cast<double(&)[2]>(d2)[1];
|
|
double& rd3 = reinterpret_cast<double(&)[2]>(d3)[0];
|
|
double& id3 = reinterpret_cast<double(&)[2]>(d3)[1];
|
|
|
|
// d.real() = a.real() * b.real() + c.real();
|
|
SM90_16x8x16_F64F64F64F64_TN::fma(
|
|
rd0, rd1, rd2, rd3,
|
|
a0.real(), a1.real(), a2.real(), a3.real(),
|
|
a4.real(), a5.real(), a6.real(), a7.real(),
|
|
b0.real(), b1.real(), b2.real(), b3.real(),
|
|
c0.real(), c1.real(), c2.real(), c3.real());
|
|
|
|
// d.imag() = a.imag() * b.real() + c.imag();
|
|
SM90_16x8x16_F64F64F64F64_TN::fma(
|
|
id0, id1, id2, id3,
|
|
a0.imag(), a1.imag(), a2.imag(), a3.imag(),
|
|
a4.imag(), a5.imag(), a6.imag(), a7.imag(),
|
|
b0.real(), b1.real(), b2.real(), b3.real(),
|
|
c0.imag(), c1.imag(), c2.imag(), c3.imag());
|
|
|
|
// d.real() = -a.imag() * b.imag() + d.real();
|
|
SM90_16x8x16_F64F64F64F64_TN::fma(
|
|
rd0, rd1, rd2, rd3,
|
|
-a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(),
|
|
-a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(),
|
|
b0.imag(), b1.imag(), b2.imag(), b3.imag(),
|
|
d0.real(), d1.real(), d2.real(), d3.real());
|
|
|
|
// d.imag() = a.real() * b.imag() + d.imag();
|
|
SM90_16x8x16_F64F64F64F64_TN::fma(
|
|
id0, id1, id2, id3,
|
|
a0.real(), a1.real(), a2.real(), a3.real(),
|
|
a4.real(), a5.real(), a6.real(), a7.real(),
|
|
b0.imag(), b1.imag(), b2.imag(), b3.imag(),
|
|
d0.imag(), d1.imag(), d2.imag(), d3.imag());
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace cute
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#include <cute/arch/mma_sm90_desc.hpp>
|
|
#include <cute/arch/mma_sm90_gmma.hpp>
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cute {
|
|
namespace GMMA {
|
|
|
|
template <
|
|
class ElementA,
|
|
class ElementB,
|
|
class ElementC,
|
|
class TileShape_MNK,
|
|
GMMA::Major MajorA = GMMA::Major::K,
|
|
GMMA::Major MajorB = GMMA::Major::K,
|
|
auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
|
|
// But most commonly leave empty for defaults
|
|
>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
ss_op_selector()
|
|
{
|
|
static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
|
|
static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
|
|
static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
|
|
auto Tile_N = size<1>(TileShape_MNK{});
|
|
|
|
// FP16 accumulator
|
|
if constexpr (is_same_v<ElementC, half_t>) {
|
|
if constexpr (is_same_v<ElementA, half_t> && is_same_v<ElementB, half_t>) {
|
|
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
|
|
|
// Dispatch against the Tile N mode size
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x16_F16F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e4m3_t ; Input B: float_e4m3_t
|
|
else if constexpr (is_same_v<ElementA, float_e4m3_t> && is_same_v<ElementB, float_e4m3_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F16E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F16E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F16E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F16E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F16E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F16E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F16E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F16E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e4m3_t ; Input B: float_e5m2_t
|
|
else if constexpr (is_same_v<ElementA, float_e4m3_t> && is_same_v<ElementB, float_e5m2_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F16E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F16E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F16E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F16E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F16E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F16E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F16E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F16E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e5m2_t ; Input B: float_e5m2_t
|
|
else if constexpr (is_same_v<ElementA, float_e5m2_t> && is_same_v<ElementB, float_e5m2_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F16E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F16E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F16E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F16E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F16E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F16E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F16E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F16E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e5m2_t ; Input B: float_e4m3_t
|
|
else if constexpr (is_same_v<ElementA, float_e5m2_t> && is_same_v<ElementB, float_e4m3_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F16E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F16E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F16E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F16E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F16E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F16E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F16E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F16E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
else {
|
|
static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
|
|
}
|
|
}
|
|
|
|
// FP32 accumulator
|
|
else if constexpr (is_same_v<ElementC, float>) {
|
|
|
|
// FP16 inputs
|
|
if constexpr (is_same_v<ElementA, half_t>) {
|
|
static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x16_F32F16F16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// BF16 inputs
|
|
else if constexpr (is_same_v<ElementA, bfloat16_t>) {
|
|
static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x16_F32BF16BF16_SS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// TF32 inputs
|
|
else if constexpr (is_same_v<ElementA, tfloat32_t>) {
|
|
static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x8_F32TF32TF32_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x8_F32TF32TF32_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x8_F32TF32TF32_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x8_F32TF32TF32_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x8_F32TF32TF32_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x8_F32TF32TF32_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x8_F32TF32TF32_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x8_F32TF32TF32_SS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e4m3_t ; Input B: float_e4m3_t
|
|
else if constexpr (is_same_v<ElementA, float_e4m3_t> && is_same_v<ElementB, float_e4m3_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F32E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F32E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F32E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F32E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F32E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F32E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F32E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F32E4M3E4M3_SS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e4m3_t ; Input B: float_e5m2_t
|
|
else if constexpr (is_same_v<ElementA, float_e4m3_t> && is_same_v<ElementB, float_e5m2_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F32E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F32E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F32E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F32E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F32E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F32E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F32E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F32E4M3E5M2_SS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e5m2_t ; Input B: float_e5m2_t
|
|
else if constexpr (is_same_v<ElementA, float_e5m2_t> && is_same_v<ElementB, float_e5m2_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F32E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F32E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F32E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F32E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F32E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F32E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F32E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F32E5M2E5M2_SS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e5m2_t ; Input B: float_e4m3_t
|
|
else if constexpr (is_same_v<ElementA, float_e5m2_t> && is_same_v<ElementB, float_e4m3_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F32E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F32E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F32E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F32E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F32E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F32E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F32E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F32E5M2E4M3_SS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
else {
|
|
static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
|
|
}
|
|
}
|
|
|
|
// S32 accumulator
|
|
else if constexpr (is_same_v<ElementC, int32_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
// ElementA == int8_t && ElementB == int8_t
|
|
if constexpr (is_same_v<ElementA, int8_t> && is_same_v<ElementB, int8_t>) {
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_S32S8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_S32S8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_S32S8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_S32S8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_S32S8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_S32S8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_S32S8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_S32S8S8_SS_TN{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// ElementA == int8_t && ElementB == uint8_t
|
|
else if constexpr (is_same_v<ElementA, int8_t> && is_same_v<ElementB, uint8_t>) {
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_S32S8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_S32S8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_S32S8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_S32S8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_S32S8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_S32S8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_S32S8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_S32S8U8_SS_TN{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// ElementA == uint8_t && ElementB == int8_t
|
|
else if constexpr (is_same_v<ElementA, uint8_t> && is_same_v<ElementB, int8_t>) {
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_S32U8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_S32U8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_S32U8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_S32U8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_S32U8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_S32U8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_S32U8S8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_S32U8S8_SS_TN{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// ElementA == uint8_t && ElementB == uint8_t
|
|
else if constexpr (is_same_v<ElementA, uint8_t> && is_same_v<ElementB, uint8_t>) {
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_S32U8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_S32U8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_S32U8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_S32U8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_S32U8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_S32U8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_S32U8U8_SS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_S32U8U8_SS_TN{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
}
|
|
|
|
// Unknown accumulator type
|
|
else {
|
|
static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type.");
|
|
}
|
|
}
|
|
|
|
template <
|
|
class ElementA,
|
|
class ElementB,
|
|
class ElementC,
|
|
class TileShape_MNK,
|
|
GMMA::Major MajorA = GMMA::Major::K,
|
|
GMMA::Major MajorB = GMMA::Major::K,
|
|
auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
|
|
// But most commonly leave empty for defaults
|
|
>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
rs_op_selector()
|
|
{
|
|
static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
|
|
static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
|
|
static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
|
|
static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout.");
|
|
auto Tile_N = size<1>(TileShape_MNK{});
|
|
|
|
// FP16 accumulator
|
|
if constexpr (is_same_v<ElementC, half_t>) {
|
|
static_assert(is_same_v<ElementA, half_t>, "Element types for AB must be half if ElementC is half.");
|
|
static_assert(is_same_v<ElementB, half_t>, "Element types for AB must be half if ElementC is half.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
|
|
|
// Dispatch against the Tile N mode size
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x16_F16F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP32 accumulator
|
|
else if constexpr (is_same_v<ElementC, float>) {
|
|
|
|
// FP16 inputs
|
|
if constexpr (is_same_v<ElementA, half_t>) {
|
|
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
|
static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x16_F32F16F16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// BF16 inputs
|
|
else if constexpr (is_same_v<ElementA, bfloat16_t>) {
|
|
static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
|
|
static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x16_F32BF16BF16_RS<MajorA, MajorB, Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// TF32 inputs
|
|
else if constexpr (is_same_v<ElementA, tfloat32_t>) {
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8.");
|
|
static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x8_F32TF32TF32_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x8_F32TF32TF32_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x8_F32TF32TF32_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x8_F32TF32TF32_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x8_F32TF32TF32_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x8_F32TF32TF32_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x8_F32TF32TF32_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x8_F32TF32TF32_RS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e4m3_t ; Input B: float_e4m3_t
|
|
else if constexpr (is_same_v<ElementA, float_e4m3_t> && is_same_v<ElementB, float_e4m3_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F32E4M3E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F32E4M3E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F32E4M3E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F32E4M3E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F32E4M3E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F32E4M3E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F32E4M3E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F32E4M3E4M3_RS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e4m3_t ; Input B: float_e5m2_t
|
|
else if constexpr (is_same_v<ElementA, float_e4m3_t> && is_same_v<ElementB, float_e5m2_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F32E4M3E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F32E4M3E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F32E4M3E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F32E4M3E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F32E4M3E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F32E4M3E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F32E4M3E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F32E4M3E5M2_RS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e5m2_t ; Input B: float_e5m2_t
|
|
else if constexpr (is_same_v<ElementA, float_e5m2_t> && is_same_v<ElementB, float_e5m2_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F32E5M2E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F32E5M2E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F32E5M2E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F32E5M2E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F32E5M2E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F32E5M2E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F32E5M2E5M2_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F32E5M2E5M2_RS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// FP8
|
|
// Input A: float_e5m2_t ; Input B: float_e4m3_t
|
|
else if constexpr (is_same_v<ElementA, float_e5m2_t> && is_same_v<ElementB, float_e4m3_t>) {
|
|
static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_F32E5M2E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_F32E5M2E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_F32E5M2E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_F32E5M2E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_F32E5M2E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_F32E5M2E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_F32E5M2E4M3_RS_TN<Args...>{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_F32E5M2E4M3_RS_TN<Args...>{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
else {
|
|
static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
|
|
}
|
|
}
|
|
|
|
// S32 accumulator
|
|
else if constexpr (is_same_v<ElementC, int32_t>) {
|
|
static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
// ElementA == int8_t && ElementB == int8_t
|
|
if constexpr (is_same_v<ElementA, int8_t> && is_same_v<ElementB, int8_t>) {
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_S32S8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_S32S8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_S32S8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_S32S8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_S32S8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_S32S8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_S32S8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_S32S8S8_RS_TN{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// ElementA == int8_t && ElementB == uint8_t
|
|
else if constexpr (is_same_v<ElementA, int8_t> && is_same_v<ElementB, uint8_t>) {
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_S32S8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_S32S8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_S32S8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_S32S8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_S32S8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_S32S8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_S32S8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_S32S8U8_RS_TN{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// ElementA == uint8_t && ElementB == int8_t
|
|
else if constexpr (is_same_v<ElementA, uint8_t> && is_same_v<ElementB, int8_t>) {
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_S32U8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_S32U8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_S32U8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_S32U8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_S32U8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_S32U8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_S32U8S8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_S32U8S8_RS_TN{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
|
|
// ElementA == uint8_t && ElementB == uint8_t
|
|
else if constexpr (is_same_v<ElementA, uint8_t> && is_same_v<ElementB, uint8_t>) {
|
|
static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
|
|
|
|
if constexpr (Tile_N % 256 == 0) {
|
|
return SM90_64x256x32_S32U8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 192 == 0) {
|
|
return SM90_64x192x32_S32U8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 128 == 0) {
|
|
return SM90_64x128x32_S32U8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 96 == 0) {
|
|
return SM90_64x96x32_S32U8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 64 == 0) {
|
|
return SM90_64x64x32_S32U8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 32 == 0) {
|
|
return SM90_64x32x32_S32U8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 16 == 0) {
|
|
return SM90_64x16x32_S32U8U8_RS_TN{};
|
|
}
|
|
else if constexpr (Tile_N % 8 == 0) {
|
|
return SM90_64x8x32_S32U8U8_RS_TN{};
|
|
}
|
|
else {
|
|
static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
|
|
}
|
|
}
|
|
}
|
|
|
|
// Unknown accumulator type
|
|
else {
|
|
static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type.");
|
|
}
|
|
}
|
|
} // end namespace GMMA
|
|
} // end namespace cute
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|