20640 lines
925 KiB
C++
20640 lines
925 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
#pragma once

#include <cute/config.hpp>

#include <cute/arch/mma.hpp>

// Config
// The wgmma wrappers below are only emitted when compiling device code for SM90 or newer with the
// architecture-specific SM90 feature set available (__CUDA_ARCH_FEAT_SM90_ALL, i.e. sm_90a).
// Otherwise every wrapper in this header takes its CUTE_INVALID_CONTROL_PATH branch.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
#  define CUTE_ARCH_MMA_SM90A_ENABLED
#endif

namespace cute {
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
// Warpgroup sync primitives

// Issues `wgmma.fence.sync.aligned` for the executing warpgroup.
// Per the PTX ISA, this fence is required before the first wgmma.mma_async of a warpgroup. It is
// also required between a register access and any wgmma.mma_async that uses the same registers as
// accumulator or A fragment. The "memory" clobber additionally keeps the compiler from moving
// memory accesses across the fence. Without CUTE_ARCH_MMA_SM90A_ENABLED this is an invalid
// control path.
CUTE_HOST_DEVICE
void
warpgroup_arrive()
{
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
  asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory");
#else
  CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90A_ENABLED");
#endif
}

template <int N>
|
|
CUTE_HOST_DEVICE
|
|
void
|
|
warpgroup_wait()
|
|
{
|
|
static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]");
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group<N> without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
|
|
// Marks the commit point for one or more sized batch of warpgroup MMAs.
|
|
CUTE_HOST_DEVICE
|
|
void
|
|
warpgroup_commit_batch()
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
|
|
// Compiler-only fence on a 32-bit register operand; the empty asm template emits no instruction.
// The "+r" constraint makes the compiler treat `reg` as read and rewritten in a register at this
// point. The "memory" clobber prevents memory accesses from being moved across it.
// NOTE(review): presumably used to pin accumulator/fragment registers around asynchronous
// wgmma sequences — confirm against callers.
CUTE_HOST_DEVICE
void
warpgroup_fence_operand(uint32_t& reg) {
  // MSVC emits a build error for 'asm volatile' even if it only occurs in a __device__ function.
  // Guarding with __CUDA_ARCH__ keeps the asm out of the host compilation pass, so the host
  // version of this function is a no-op and the error is avoided.
#if defined(__CUDA_ARCH__)
  asm volatile("" : "+r"(reg) :: "memory");
#endif
}

// Compiler-only fence on a 32-bit floating-point register operand (the float counterpart of the
// uint32_t overload). The empty asm emits no instruction. The "+f" constraint makes the compiler
// treat `reg` as read and rewritten in a register here. The "memory" clobber prevents memory
// accesses from being moved across it.
CUTE_HOST_DEVICE
void
warpgroup_fence_operand(float& reg) {
  // MSVC rejects 'asm volatile' even inside device-only code; the __CUDA_ARCH__ guard keeps it
  // out of the host compilation pass, making the host version a no-op.
#if defined(__CUDA_ARCH__)
  asm volatile("" : "+f"(reg) :: "memory");
#endif
}

namespace GMMA {

// Operand majorness. The GMMA wrappers pass it to wgmma as the imm-trans-a / imm-trans-b
// immediate: K (0) is the non-transposed layout, MN (1) requests the transposed layout.
enum class Major {
  K  = 0,
  MN = 1
};

// Accumulator handling, lowered to the scale-d predicate of wgmma via `setp.ne.b32`:
// Zero => D = A*B (existing accumulator ignored), One => D = A*B + D.
enum class ScaleOut {
  Zero = 0,
  One  = 1
};

// Input scaling passed as the imm-scale-a / imm-scale-b immediate:
// Neg (-1) negates the operand, One (1) uses it unchanged.
enum class ScaleIn {
  Neg = -1,
  One = 1
};

} // namespace GMMA

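////////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C)
// scaleA/scaleB are compile-time +1/-1 (GMMA::ScaleIn); scaleD (GMMA::ScaleOut) selects whether
// the existing accumulator C is added (One) or ignored (Zero).
////////////////////////////////////////////////////////////////////////////////////////////////////
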
// GMMA 64x8x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x16_F16F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[2];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %4, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
|
|
"{%0, %1},"
|
|
" %2,"
|
|
" %3,"
|
|
" p, %5, %6, %7, %8;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x16_F16F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[2];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %7, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
|
|
"{%0, %1},"
|
|
"{%2, %3, %4, %5},"
|
|
" %6,"
|
|
" p, %8, %9, %10;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x16_F16F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8, %9, %10;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x16_F16F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11, %12;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x16_F16F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12, %13, %14;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x16_F16F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15, %16;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x16_F16F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20, %21, %22;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x16_F16F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23, %24;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x16_F16F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[24];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %26, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
|
" %24,"
|
|
" %25,"
|
|
" p, %27, %28, %29, %30;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x16_F16F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[24];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %29, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
|
"{%24, %25, %26, %27},"
|
|
" %28,"
|
|
" p, %30, %31, %32;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x16_F16F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36, %37, %38;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x16_F16F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39, %40;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x16_F16F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52, %53, %54;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x16 F16+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x16_F16F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55, %56;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x16 F16+=F16*F16
//
// SS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16`.
// A and B are both passed as 64-bit matrix descriptors (desc_a, desc_b); presumably
// shared-memory descriptors, as the "SS" suffix suggests -- confirm.
// Accumulator: 64 uint32_t registers per thread (d00..d63, see CRegisters), bound with
// "+r" so the same registers are read as C and written as D (hence DRegisters = void).
// NOTE(review): each 32-bit register presumably packs two f16 values -- confirm against
// the PTX ISA wgmma register-fragment layout.
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspA/tnspB are emitted as compile-time immediates ("n" constraints).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x16_F16F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%63 = d00..d63, %64 = desc_a, %65 = desc_b, %66 = scale_D (tested by
// setp), %67 = scaleA, %68 = scaleB, %69 = tnspA, %70 = tnspB. Keep these indices in sync
// with the constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68, %69, %70;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x16 F16+=F16*F16
//
// RS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16`
// with A taken from four 32-bit registers per thread (a00..a03, see ARegisters) and B
// from a 64-bit matrix descriptor (desc_b).
// Accumulator: 64 uint32_t registers per thread (d00..d63, see CRegisters), bound with
// "+r" so the same registers are read as C and written as D (hence DRegisters = void).
// NOTE(review): each 32-bit register presumably packs two f16 values -- confirm against
// the PTX ISA wgmma register-fragment layout.
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspB are compile-time immediates ("n"); no tnspA immediate is emitted
// because A is required to be K-major (see the static_assert).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x16_F16F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
// tnspA is only validated here; it is not forwarded to the PTX instruction.
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%63 = d00..d63, %64-%67 = a00..a03, %68 = desc_b, %69 = scale_D (tested
// by setp), %70 = scaleA, %71 = scaleB, %72 = tnspB. Keep these indices in sync with the
// constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71, %72;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x16 F32+=F16*F16
//
// SS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16`.
// A and B are both passed as 64-bit matrix descriptors (desc_a, desc_b); presumably
// shared-memory descriptors, as the "SS" suffix suggests -- confirm.
// Accumulator: 4 float registers per thread (d0..d3, see CRegisters), bound with "+f" so
// the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspA/tnspB are emitted as compile-time immediates ("n" constraints).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x16_F32F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%3 = d0..d3, %4 = desc_a, %5 = desc_b, %6 = scale_D (tested by setp),
// %7 = scaleA, %8 = scaleB, %9 = tnspA, %10 = tnspB. Keep these indices in sync with the
// constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8, %9, %10;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x16 F32+=F16*F16
//
// RS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16` with A
// taken from four 32-bit registers per thread (a0..a3, see ARegisters) and B from a
// 64-bit matrix descriptor (desc_b).
// Accumulator: 4 float registers per thread (d0..d3, see CRegisters), bound with "+f" so
// the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspB are compile-time immediates ("n"); no tnspA immediate is emitted
// because A is required to be K-major (see the static_assert).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x16_F32F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
// tnspA is only validated here; it is not forwarded to the PTX instruction.
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%3 = d0..d3, %4-%7 = a0..a3, %8 = desc_b, %9 = scale_D (tested by setp),
// %10 = scaleA, %11 = scaleB, %12 = tnspB. Keep these indices in sync with the constraint
// lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11, %12;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x16 F32+=F16*F16
//
// SS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16`.
// A and B are both passed as 64-bit matrix descriptors (desc_a, desc_b); presumably
// shared-memory descriptors, as the "SS" suffix suggests -- confirm.
// Accumulator: 8 float registers per thread (d0..d7, see CRegisters), bound with "+f" so
// the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspA/tnspB are emitted as compile-time immediates ("n" constraints).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x16_F32F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%7 = d0..d7, %8 = desc_a, %9 = desc_b, %10 = scale_D (tested by setp),
// %11 = scaleA, %12 = scaleB, %13 = tnspA, %14 = tnspB. Keep these indices in sync with
// the constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12, %13, %14;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x16 F32+=F16*F16
//
// RS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16` with A
// taken from four 32-bit registers per thread (a0..a3, see ARegisters) and B from a
// 64-bit matrix descriptor (desc_b).
// Accumulator: 8 float registers per thread (d0..d7, see CRegisters), bound with "+f" so
// the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspB are compile-time immediates ("n"); no tnspA immediate is emitted
// because A is required to be K-major (see the static_assert).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x16_F32F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
// tnspA is only validated here; it is not forwarded to the PTX instruction.
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%7 = d0..d7, %8-%11 = a0..a3, %12 = desc_b, %13 = scale_D (tested by
// setp), %14 = scaleA, %15 = scaleB, %16 = tnspB. Keep these indices in sync with the
// constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15, %16;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x16 F32+=F16*F16
//
// SS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16`.
// A and B are both passed as 64-bit matrix descriptors (desc_a, desc_b); presumably
// shared-memory descriptors, as the "SS" suffix suggests -- confirm.
// Accumulator: 16 float registers per thread (d00..d15, see CRegisters), bound with "+f"
// so the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspA/tnspB are emitted as compile-time immediates ("n" constraints).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x16_F32F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%15 = d00..d15, %16 = desc_a, %17 = desc_b, %18 = scale_D (tested by
// setp), %19 = scaleA, %20 = scaleB, %21 = tnspA, %22 = tnspB. Keep these indices in sync
// with the constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20, %21, %22;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x16 F32+=F16*F16
//
// RS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16` with A
// taken from four 32-bit registers per thread (a00..a03, see ARegisters) and B from a
// 64-bit matrix descriptor (desc_b).
// Accumulator: 16 float registers per thread (d00..d15, see CRegisters), bound with "+f"
// so the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspB are compile-time immediates ("n"); no tnspA immediate is emitted
// because A is required to be K-major (see the static_assert).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x16_F32F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
// tnspA is only validated here; it is not forwarded to the PTX instruction.
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%15 = d00..d15, %16-%19 = a00..a03, %20 = desc_b, %21 = scale_D (tested
// by setp), %22 = scaleA, %23 = scaleB, %24 = tnspB. Keep these indices in sync with the
// constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23, %24;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x16 F32+=F16*F16
//
// SS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16`.
// A and B are both passed as 64-bit matrix descriptors (desc_a, desc_b); presumably
// shared-memory descriptors, as the "SS" suffix suggests -- confirm.
// Accumulator: 32 float registers per thread (d00..d31, see CRegisters), bound with "+f"
// so the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspA/tnspB are emitted as compile-time immediates ("n" constraints).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x16_F32F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%31 = d00..d31, %32 = desc_a, %33 = desc_b, %34 = scale_D (tested by
// setp), %35 = scaleA, %36 = scaleB, %37 = tnspA, %38 = tnspB. Keep these indices in sync
// with the constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36, %37, %38;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x16 F32+=F16*F16
//
// RS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16` with A
// taken from four 32-bit registers per thread (a00..a03, see ARegisters) and B from a
// 64-bit matrix descriptor (desc_b).
// Accumulator: 32 float registers per thread (d00..d31, see CRegisters), bound with "+f"
// so the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspB are compile-time immediates ("n"); no tnspA immediate is emitted
// because A is required to be K-major (see the static_assert).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x16_F32F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
// tnspA is only validated here; it is not forwarded to the PTX instruction.
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%31 = d00..d31, %32-%35 = a00..a03, %36 = desc_b, %37 = scale_D (tested
// by setp), %38 = scaleA, %39 = scaleB, %40 = tnspB. Keep these indices in sync with the
// constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39, %40;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x16 F32+=F16*F16
//
// SS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16`.
// A and B are both passed as 64-bit matrix descriptors (desc_a, desc_b); presumably
// shared-memory descriptors, as the "SS" suffix suggests -- confirm.
// Accumulator: 48 float registers per thread (d00..d47, see CRegisters), bound with "+f"
// so the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspA/tnspB are emitted as compile-time immediates ("n" constraints).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x16_F32F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%47 = d00..d47, %48 = desc_a, %49 = desc_b, %50 = scale_D (tested by
// setp), %51 = scaleA, %52 = scaleB, %53 = tnspA, %54 = tnspB. Keep these indices in sync
// with the constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52, %53, %54;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x16 F32+=F16*F16
//
// RS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16` with A
// taken from four 32-bit registers per thread (a00..a03, see ARegisters) and B from a
// 64-bit matrix descriptor (desc_b).
// Accumulator: 48 float registers per thread (d00..d47, see CRegisters), bound with "+f"
// so the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspB are compile-time immediates ("n"); no tnspA immediate is emitted
// because A is required to be K-major (see the static_assert).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x16_F32F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
// tnspA is only validated here; it is not forwarded to the PTX instruction.
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%47 = d00..d47, %48-%51 = a00..a03, %52 = desc_b, %53 = scale_D (tested
// by setp), %54 = scaleA, %55 = scaleB, %56 = tnspB. Keep these indices in sync with the
// constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55, %56;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x16 F32+=F16*F16
//
// SS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16`.
// A and B are both passed as 64-bit matrix descriptors (desc_a, desc_b); presumably
// shared-memory descriptors, as the "SS" suffix suggests -- confirm.
// Accumulator: 64 float registers per thread (d00..d63, see CRegisters), bound with "+f"
// so the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspA/tnspB are emitted as compile-time immediates ("n" constraints).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x16_F32F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%63 = d00..d63, %64 = desc_a, %65 = desc_b, %66 = scale_D (tested by
// setp), %67 = scaleA, %68 = scaleB, %69 = tnspA, %70 = tnspB. Keep these indices in sync
// with the constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68, %69, %70;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x16 F32+=F16*F16
//
// RS variant: issues one PTX `wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16` with A
// taken from four 32-bit registers per thread (a00..a03, see ARegisters) and B from a
// 64-bit matrix descriptor (desc_b).
// Accumulator: 64 float registers per thread (d00..d63, see CRegisters), bound with "+f"
// so the same registers are read as C and written as D (hence DRegisters = void).
// scale_D becomes predicate p (p = scale_D != 0) for the scale-d operand; per the PTX ISA,
// p == false means D = A*B without adding the previous accumulator contents.
// scaleA/scaleB/tnspB are compile-time immediates ("n"); no tnspA immediate is emitted
// because A is required to be K-major (see the static_assert).
// NOTE(review): the instruction is asynchronous; any wgmma fence/commit/wait sequence is
// presumably the caller's responsibility -- not visible here.
// Without CUTE_ARCH_MMA_SM90A_ENABLED, fma() reaches CUTE_INVALID_CONTROL_PATH.
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x16_F32F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
// tnspA is only validated here; it is not forwarded to the PTX instruction.
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
// Operand map: %0-%63 = d00..d63, %64-%67 = a00..a03, %68 = desc_b, %69 = scale_D (tested
// by setp), %70 = scaleA, %71 = scaleB, %72 = tnspB. Keep these indices in sync with the
// constraint lists if any operand is added, removed, or reordered.
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71, %72;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x16 F32+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x16_F32F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p, %99, %100, %101, %102;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x16 F32+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x16_F32F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p, %102, %103, %104;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x16 F32+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x16_F32F16F16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p, %131, %132, %133, %134;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x16 F32+=F16*F16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x16_F32F16F16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p, %134, %135, %136;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x16_F32BF16BF16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8, %9, %10;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x16_F32BF16BF16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11, %12;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x16_F32BF16BF16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12, %13, %14;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x16_F32BF16BF16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15, %16;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x16_F32BF16BF16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20, %21, %22;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x16_F32BF16BF16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23, %24;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x16_F32BF16BF16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36, %37, %38;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x16_F32BF16BF16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39, %40;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x16_F32BF16BF16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52, %53, %54;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x16_F32BF16BF16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55, %56;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x16_F32BF16BF16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68, %69, %70;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x16_F32BF16BF16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71, %72;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x16_F32BF16BF16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p, %99, %100, %101, %102;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x16_F32BF16BF16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p, %102, %103, %104;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x16_F32BF16BF16_SS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p, %131, %132, %133, %134;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x16 F32+=BF16*BF16
|
|
template <
|
|
GMMA::Major tnspA,
|
|
GMMA::Major tnspB,
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x16_F32BF16BF16_RS
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
static_assert(tnspA == GMMA::Major::K,
|
|
"Register source operand A must have K major layout.");
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p, %134, %135, %136;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x8_F32TF32TF32_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x8_F32TF32TF32_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x8_F32TF32TF32_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x8_F32TF32TF32_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x8_F32TF32TF32_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x8_F32TF32TF32_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x8_F32TF32TF32_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x8_F32TF32TF32_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x8_F32TF32TF32_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x8_F32TF32TF32_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x8_F32TF32TF32_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x8_F32TF32TF32_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x8_F32TF32TF32_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p, %99, %100;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x8_F32TF32TF32_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p, %102, %103;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x8_F32TF32TF32_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p, %131, %132;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x8 TN F32+=TF32*TF32
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x8_F32TF32TF32_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p, %134, %135;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=S8*S8
|
|
struct SM90_64x8x32_S32S8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=S8*S8
|
|
struct SM90_64x8x32_S32S8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=S8*S8
|
|
struct SM90_64x16x32_S32S8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=S8*S8
|
|
struct SM90_64x16x32_S32S8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=S8*S8
|
|
struct SM90_64x32x32_S32S8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=S8*S8
|
|
struct SM90_64x32x32_S32S8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=S8*S8
|
|
struct SM90_64x64x32_S32S8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=S8*S8
|
|
struct SM90_64x64x32_S32S8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=S8*S8
|
|
struct SM90_64x96x32_S32S8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=S8*S8
|
|
struct SM90_64x96x32_S32S8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=S8*S8
|
|
struct SM90_64x128x32_S32S8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=S8*S8
|
|
struct SM90_64x128x32_S32S8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=S8*S8
|
|
struct SM90_64x192x32_S32S8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=S8*S8
|
|
struct SM90_64x192x32_S32S8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=S8*S8
|
|
struct SM90_64x256x32_S32S8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=S8*S8
|
|
struct SM90_64x256x32_S32S8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=S8*S8
|
|
struct SM90_64x8x32_S32S8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=S8*S8
|
|
struct SM90_64x8x32_S32S8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=S8*S8
|
|
struct SM90_64x16x32_S32S8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=S8*S8
|
|
struct SM90_64x16x32_S32S8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=S8*S8
|
|
struct SM90_64x32x32_S32S8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x32x32 TN S32+=S8*S8
|
|
struct SM90_64x32x32_S32S8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x64x32 TN S32+=S8*S8
|
|
struct SM90_64x64x32_S32S8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x64x32 TN S32+=S8*S8
|
|
struct SM90_64x64x32_S32S8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x96x32 TN S32+=S8*S8
|
|
struct SM90_64x96x32_S32S8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x96x32 TN S32+=S8*S8
|
|
struct SM90_64x96x32_S32S8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x128x32 TN S32+=S8*S8
|
|
struct SM90_64x128x32_S32S8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x128x32 TN S32+=S8*S8
|
|
struct SM90_64x128x32_S32S8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x192x32 TN S32+=S8*S8
|
|
struct SM90_64x192x32_S32S8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x192x32 TN S32+=S8*S8
|
|
struct SM90_64x192x32_S32S8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA 64x256x32 TN S32+=S8*S8
|
|
struct SM90_64x256x32_S32S8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=S8*S8
|
|
struct SM90_64x256x32_S32S8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=S8*U8
|
|
struct SM90_64x8x32_S32S8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=S8*U8
|
|
struct SM90_64x8x32_S32S8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=S8*U8
|
|
struct SM90_64x16x32_S32S8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=S8*U8
|
|
struct SM90_64x16x32_S32S8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=S8*U8
|
|
struct SM90_64x32x32_S32S8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=S8*U8
|
|
struct SM90_64x32x32_S32S8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=S8*U8
|
|
struct SM90_64x64x32_S32S8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=S8*U8
|
|
struct SM90_64x64x32_S32S8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=S8*U8
|
|
struct SM90_64x96x32_S32S8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=S8*U8
|
|
struct SM90_64x96x32_S32S8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=S8*U8
|
|
struct SM90_64x128x32_S32S8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=S8*U8
|
|
struct SM90_64x128x32_S32S8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=S8*U8
|
|
struct SM90_64x192x32_S32S8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=S8*U8 (saturating, .satfinite)
|
|
struct SM90_64x192x32_S32S8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=S8*U8
|
|
struct SM90_64x256x32_S32S8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=S8*U8 (saturating, .satfinite)
|
|
struct SM90_64x256x32_S32S8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=S8*U8
|
|
struct SM90_64x8x32_S32S8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=S8*U8 (saturating, .satfinite)
|
|
struct SM90_64x8x32_S32S8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=S8*U8
|
|
struct SM90_64x16x32_S32S8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=S8*U8 (saturating, .satfinite)
|
|
struct SM90_64x16x32_S32S8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=S8*U8
|
|
struct SM90_64x32x32_S32S8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=S8*U8 (saturating, .satfinite)
|
|
struct SM90_64x32x32_S32S8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=S8*U8
|
|
struct SM90_64x64x32_S32S8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=S8*U8 (saturating, .satfinite)
|
|
struct SM90_64x64x32_S32S8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=S8*U8
|
|
struct SM90_64x96x32_S32S8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=S8*U8
|
|
struct SM90_64x96x32_S32S8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=S8*U8
|
|
struct SM90_64x128x32_S32S8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=S8*U8
|
|
struct SM90_64x128x32_S32S8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=S8*U8
|
|
struct SM90_64x192x32_S32S8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=S8*U8
|
|
struct SM90_64x192x32_S32S8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=S8*U8
|
|
struct SM90_64x256x32_S32S8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=S8*U8
|
|
struct SM90_64x256x32_S32S8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=U8*S8
|
|
struct SM90_64x8x32_S32U8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=U8*S8
|
|
struct SM90_64x8x32_S32U8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=U8*S8
|
|
struct SM90_64x16x32_S32U8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=U8*S8
|
|
struct SM90_64x16x32_S32U8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=U8*S8
|
|
struct SM90_64x32x32_S32U8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=U8*S8
|
|
struct SM90_64x32x32_S32U8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=U8*S8
|
|
struct SM90_64x64x32_S32U8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=U8*S8
|
|
struct SM90_64x64x32_S32U8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=U8*S8
|
|
struct SM90_64x96x32_S32U8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=U8*S8
|
|
struct SM90_64x96x32_S32U8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=U8*S8
|
|
struct SM90_64x128x32_S32U8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=U8*S8
|
|
struct SM90_64x128x32_S32U8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=U8*S8
|
|
struct SM90_64x192x32_S32U8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=U8*S8
|
|
struct SM90_64x192x32_S32U8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=U8*S8
|
|
struct SM90_64x256x32_S32U8S8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=U8*S8
|
|
struct SM90_64x256x32_S32U8S8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=U8*S8
|
|
struct SM90_64x8x32_S32U8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=U8*S8
|
|
struct SM90_64x8x32_S32U8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=U8*S8
|
|
struct SM90_64x16x32_S32U8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=U8*S8
|
|
struct SM90_64x16x32_S32U8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=U8*S8
|
|
struct SM90_64x32x32_S32U8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=U8*S8
|
|
struct SM90_64x32x32_S32U8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=U8*S8
|
|
struct SM90_64x64x32_S32U8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=U8*S8
|
|
struct SM90_64x64x32_S32U8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=U8*S8
|
|
struct SM90_64x96x32_S32U8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=U8*S8
|
|
struct SM90_64x96x32_S32U8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=U8*S8
|
|
struct SM90_64x128x32_S32U8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=U8*S8
|
|
struct SM90_64x128x32_S32U8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=U8*S8
|
|
struct SM90_64x192x32_S32U8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=U8*S8
|
|
struct SM90_64x192x32_S32U8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=U8*S8
|
|
struct SM90_64x256x32_S32U8S8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=U8*S8
|
|
struct SM90_64x256x32_S32U8S8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=U8*U8
|
|
struct SM90_64x8x32_S32U8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=U8*U8
|
|
struct SM90_64x8x32_S32U8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=U8*U8
|
|
struct SM90_64x16x32_S32U8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=U8*U8
|
|
struct SM90_64x16x32_S32U8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=U8*U8
|
|
struct SM90_64x32x32_S32U8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=U8*U8
|
|
struct SM90_64x32x32_S32U8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=U8*U8
|
|
struct SM90_64x64x32_S32U8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=U8*U8
|
|
struct SM90_64x64x32_S32U8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=U8*U8
|
|
struct SM90_64x96x32_S32U8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=U8*U8
|
|
struct SM90_64x96x32_S32U8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=U8*U8
|
|
struct SM90_64x128x32_S32U8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=U8*U8
|
|
struct SM90_64x128x32_S32U8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=U8*U8
|
|
struct SM90_64x192x32_S32U8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=U8*U8
|
|
struct SM90_64x192x32_S32U8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=U8*U8
|
|
struct SM90_64x256x32_S32U8U8_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=U8*U8
|
|
struct SM90_64x256x32_S32U8U8_SS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=U8*U8
|
|
struct SM90_64x8x32_S32U8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN S32+=U8*U8
|
|
struct SM90_64x8x32_S32U8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=U8*U8
|
|
struct SM90_64x16x32_S32U8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN S32+=U8*U8
|
|
struct SM90_64x16x32_S32U8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=U8*U8
|
|
struct SM90_64x32x32_S32U8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN S32+=U8*U8
|
|
struct SM90_64x32x32_S32U8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=U8*U8
|
|
struct SM90_64x64x32_S32U8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN S32+=U8*U8
|
|
struct SM90_64x64x32_S32U8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=U8*U8
|
|
struct SM90_64x96x32_S32U8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN S32+=U8*U8
|
|
struct SM90_64x96x32_S32U8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=U8*U8
|
|
struct SM90_64x128x32_S32U8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN S32+=U8*U8
|
|
struct SM90_64x128x32_S32U8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=U8*U8
|
|
struct SM90_64x192x32_S32U8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN S32+=U8*U8
|
|
struct SM90_64x192x32_S32U8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67,
|
|
uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71,
|
|
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75,
|
|
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79,
|
|
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83,
|
|
uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87,
|
|
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91,
|
|
uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63),
|
|
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67),
|
|
"+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71),
|
|
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75),
|
|
"+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79),
|
|
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83),
|
|
"+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87),
|
|
"+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91),
|
|
"+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=U8*U8
|
|
struct SM90_64x256x32_S32U8U8_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN S32+=U8*U8
|
|
struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003,
|
|
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007,
|
|
uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011,
|
|
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015,
|
|
uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019,
|
|
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023,
|
|
uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027,
|
|
uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031,
|
|
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035,
|
|
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039,
|
|
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043,
|
|
uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047,
|
|
uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051,
|
|
uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055,
|
|
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059,
|
|
uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063,
|
|
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067,
|
|
uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071,
|
|
uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075,
|
|
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079,
|
|
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083,
|
|
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087,
|
|
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091,
|
|
uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095,
|
|
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099,
|
|
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103,
|
|
uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107,
|
|
uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111,
|
|
uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115,
|
|
uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119,
|
|
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123,
|
|
uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p;\n"
|
|
"}\n"
|
|
: "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003),
|
|
"+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007),
|
|
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011),
|
|
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015),
|
|
"+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019),
|
|
"+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023),
|
|
"+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027),
|
|
"+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031),
|
|
"+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035),
|
|
"+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039),
|
|
"+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043),
|
|
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047),
|
|
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051),
|
|
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055),
|
|
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059),
|
|
"+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063),
|
|
"+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067),
|
|
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071),
|
|
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075),
|
|
"+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079),
|
|
"+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083),
|
|
"+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087),
|
|
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091),
|
|
"+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095),
|
|
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099),
|
|
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103),
|
|
"+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107),
|
|
"+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111),
|
|
"+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115),
|
|
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119),
|
|
"+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123),
|
|
"+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F16E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[2];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %4, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 "
|
|
"{%0, %1},"
|
|
" %2,"
|
|
" %3,"
|
|
" p, %5, %6;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F16E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[2];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %7, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 "
|
|
"{%0, %1},"
|
|
"{%2, %3, %4, %5},"
|
|
" %6,"
|
|
" p, %8, %9;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F32E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F32E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F16E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F16E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F32E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F32E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F16E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F16E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F32E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F32E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F16E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F16E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F32E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F32E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F16E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[24];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %26, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
|
" %24,"
|
|
" %25,"
|
|
" p, %27, %28;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F16E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[24];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %29, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
|
"{%24, %25, %26, %27},"
|
|
" %28,"
|
|
" p, %30, %31;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F32E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F32E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F16E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F16E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F32E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F32E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F16E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F16E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F32E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p, %99, %100;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F32E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p, %102, %103;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F16E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F16+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F16E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F32E4M3E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p, %131, %132;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F32+=E4M3*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F32E4M3E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p, %134, %135;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F16E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[2];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %4, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 "
|
|
"{%0, %1},"
|
|
" %2,"
|
|
" %3,"
|
|
" p, %5, %6;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F16E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[2];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %7, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 "
|
|
"{%0, %1},"
|
|
"{%2, %3, %4, %5},"
|
|
" %6,"
|
|
" p, %8, %9;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F32E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F32E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F16E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F16E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F32E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F32E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F16E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F16E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F32E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F32E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F16E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F16E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F32E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F32E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F16E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[24];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %26, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
|
" %24,"
|
|
" %25,"
|
|
" p, %27, %28;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F16E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[24];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %29, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
|
"{%24, %25, %26, %27},"
|
|
" %28,"
|
|
" p, %30, %31;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F32E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F32E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F16E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F16E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F32E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F32E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F16E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F16E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F32E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p, %99, %100;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F32E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p, %102, %103;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F16E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F16+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F16E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F32E4M3E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p, %131, %132;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F32+=E4M3*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F32E4M3E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p, %134, %135;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F16E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[2];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %4, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 "
|
|
"{%0, %1},"
|
|
" %2,"
|
|
" %3,"
|
|
" p, %5, %6;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F16E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[2];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %7, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 "
|
|
"{%0, %1},"
|
|
"{%2, %3, %4, %5},"
|
|
" %6,"
|
|
" p, %8, %9;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F32E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F32E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F16E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F16E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F32E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F32E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F16E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F16E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F32E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F32E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F16E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F16E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F32E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F32E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F16E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[24];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %26, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
|
" %24,"
|
|
" %25,"
|
|
" p, %27, %28;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F16E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[24];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %29, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
|
"{%24, %25, %26, %27},"
|
|
" %28,"
|
|
" p, %30, %31;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F32E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F32E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F16E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F16E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F32E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F32E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F16E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F16E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F32E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p, %99, %100;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F32E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p, %102, %103;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F16E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F16+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F16E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F32E5M2E4M3_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p, %131, %132;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F32+=E5M2*E4M3
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F32E5M2E4M3_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p, %134, %135;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F16E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[2];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %4, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 "
|
|
"{%0, %1},"
|
|
" %2,"
|
|
" %3,"
|
|
" p, %5, %6;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F16E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[2];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %7, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 "
|
|
"{%0, %1},"
|
|
"{%2, %3, %4, %5},"
|
|
" %6,"
|
|
" p, %8, %9;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F32E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x8x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x8x32_F32E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F16E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %6, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3},"
|
|
" %4,"
|
|
" %5,"
|
|
" p, %7, %8;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F16E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[4];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %9, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3},"
|
|
"{%4, %5, %6, %7},"
|
|
" %8,"
|
|
" p, %10, %11;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F32E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x16x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x16x32_F32E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
float & d0, float & d1, float & d2, float & d3,
|
|
float & d4, float & d5, float & d6, float & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15;\n"
|
|
"}\n"
|
|
: "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3),
|
|
"+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F16E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %10, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
" %8,"
|
|
" %9,"
|
|
" p, %11, %12;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F16E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[8];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
|
|
uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %13, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7},"
|
|
"{%8, %9, %10, %11},"
|
|
" %12,"
|
|
" p, %14, %15;\n"
|
|
"}\n"
|
|
: "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
|
|
"+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7)
|
|
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F32E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x32x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x32x32_F32E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F16E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %18, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
" %16,"
|
|
" %17,"
|
|
" p, %19, %20;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F16E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[16];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %21, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15},"
|
|
"{%16, %17, %18, %19},"
|
|
" %20,"
|
|
" p, %22, %23;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F32E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x64x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x64x32_F32E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F16E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[24];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %26, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
|
" %24,"
|
|
" %25,"
|
|
" p, %27, %28;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F16E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[24];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %29, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23},"
|
|
"{%24, %25, %26, %27},"
|
|
" %28,"
|
|
" p, %30, %31;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F32E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x96x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x96x32_F32E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F16E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %34, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
" %32,"
|
|
" %33,"
|
|
" p, %35, %36;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F16E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[32];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %37, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31},"
|
|
"{%32, %33, %34, %35},"
|
|
" %36,"
|
|
" p, %38, %39;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F32E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x128x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x128x32_F32E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F16E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %50, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
" %48,"
|
|
" %49,"
|
|
" p, %51, %52;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F16E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[48];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %53, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47},"
|
|
"{%48, %49, %50, %51},"
|
|
" %52,"
|
|
" p, %54, %55;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F32E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %98, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
" %96,"
|
|
" %97,"
|
|
" p, %99, %100;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x192x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x192x32_F32E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[96];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
float & d00, float & d01, float & d02, float & d03,
|
|
float & d04, float & d05, float & d06, float & d07,
|
|
float & d08, float & d09, float & d10, float & d11,
|
|
float & d12, float & d13, float & d14, float & d15,
|
|
float & d16, float & d17, float & d18, float & d19,
|
|
float & d20, float & d21, float & d22, float & d23,
|
|
float & d24, float & d25, float & d26, float & d27,
|
|
float & d28, float & d29, float & d30, float & d31,
|
|
float & d32, float & d33, float & d34, float & d35,
|
|
float & d36, float & d37, float & d38, float & d39,
|
|
float & d40, float & d41, float & d42, float & d43,
|
|
float & d44, float & d45, float & d46, float & d47,
|
|
float & d48, float & d49, float & d50, float & d51,
|
|
float & d52, float & d53, float & d54, float & d55,
|
|
float & d56, float & d57, float & d58, float & d59,
|
|
float & d60, float & d61, float & d62, float & d63,
|
|
float & d64, float & d65, float & d66, float & d67,
|
|
float & d68, float & d69, float & d70, float & d71,
|
|
float & d72, float & d73, float & d74, float & d75,
|
|
float & d76, float & d77, float & d78, float & d79,
|
|
float & d80, float & d81, float & d82, float & d83,
|
|
float & d84, float & d85, float & d86, float & d87,
|
|
float & d88, float & d89, float & d90, float & d91,
|
|
float & d92, float & d93, float & d94, float & d95,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %101, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95},"
|
|
"{%96, %97, %98, %99},"
|
|
" %100,"
|
|
" p, %102, %103;\n"
|
|
"}\n"
|
|
: "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03),
|
|
"+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07),
|
|
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11),
|
|
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15),
|
|
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19),
|
|
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23),
|
|
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27),
|
|
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31),
|
|
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35),
|
|
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39),
|
|
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43),
|
|
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47),
|
|
"+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51),
|
|
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55),
|
|
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59),
|
|
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63),
|
|
"+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67),
|
|
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71),
|
|
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75),
|
|
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79),
|
|
"+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83),
|
|
"+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87),
|
|
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91),
|
|
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F16E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %66, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
" %64,"
|
|
" %65,"
|
|
" p, %67, %68;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F16+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F16E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = uint32_t[64];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03,
|
|
uint64_t const& desc_b,
|
|
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03,
|
|
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07,
|
|
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11,
|
|
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15,
|
|
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19,
|
|
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23,
|
|
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27,
|
|
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31,
|
|
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35,
|
|
uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39,
|
|
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43,
|
|
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47,
|
|
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51,
|
|
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55,
|
|
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59,
|
|
uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %69, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63},"
|
|
"{%64, %65, %66, %67},"
|
|
" %68,"
|
|
" p, %70, %71;\n"
|
|
"}\n"
|
|
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03),
|
|
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07),
|
|
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11),
|
|
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15),
|
|
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19),
|
|
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23),
|
|
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27),
|
|
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31),
|
|
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35),
|
|
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39),
|
|
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43),
|
|
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47),
|
|
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51),
|
|
"+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55),
|
|
"+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59),
|
|
"+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63)
|
|
: "r"(a00), "r"(a01), "r"(a02), "r"(a03),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F32E5M2E5M2_SS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint64_t[1];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint64_t const& desc_a,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %130, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
" %128,"
|
|
" %129,"
|
|
" p, %131, %132;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "l"(desc_a),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// GMMA 64x256x32 TN F32+=E5M2*E5M2
|
|
template <
|
|
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
|
|
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
|
|
>
|
|
struct SM90_64x256x32_F32E5M2E5M2_RS_TN
|
|
{
|
|
using DRegisters = void;
|
|
using ARegisters = uint32_t[4];
|
|
using BRegisters = uint64_t[1];
|
|
using CRegisters = float[128];
|
|
|
|
CUTE_HOST_DEVICE static void
|
|
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003,
|
|
uint64_t const& desc_b,
|
|
float & d000, float & d001, float & d002, float & d003,
|
|
float & d004, float & d005, float & d006, float & d007,
|
|
float & d008, float & d009, float & d010, float & d011,
|
|
float & d012, float & d013, float & d014, float & d015,
|
|
float & d016, float & d017, float & d018, float & d019,
|
|
float & d020, float & d021, float & d022, float & d023,
|
|
float & d024, float & d025, float & d026, float & d027,
|
|
float & d028, float & d029, float & d030, float & d031,
|
|
float & d032, float & d033, float & d034, float & d035,
|
|
float & d036, float & d037, float & d038, float & d039,
|
|
float & d040, float & d041, float & d042, float & d043,
|
|
float & d044, float & d045, float & d046, float & d047,
|
|
float & d048, float & d049, float & d050, float & d051,
|
|
float & d052, float & d053, float & d054, float & d055,
|
|
float & d056, float & d057, float & d058, float & d059,
|
|
float & d060, float & d061, float & d062, float & d063,
|
|
float & d064, float & d065, float & d066, float & d067,
|
|
float & d068, float & d069, float & d070, float & d071,
|
|
float & d072, float & d073, float & d074, float & d075,
|
|
float & d076, float & d077, float & d078, float & d079,
|
|
float & d080, float & d081, float & d082, float & d083,
|
|
float & d084, float & d085, float & d086, float & d087,
|
|
float & d088, float & d089, float & d090, float & d091,
|
|
float & d092, float & d093, float & d094, float & d095,
|
|
float & d096, float & d097, float & d098, float & d099,
|
|
float & d100, float & d101, float & d102, float & d103,
|
|
float & d104, float & d105, float & d106, float & d107,
|
|
float & d108, float & d109, float & d110, float & d111,
|
|
float & d112, float & d113, float & d114, float & d115,
|
|
float & d116, float & d117, float & d118, float & d119,
|
|
float & d120, float & d121, float & d122, float & d123,
|
|
float & d124, float & d125, float & d126, float & d127,
|
|
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
|
|
{
|
|
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
|
|
asm volatile(
|
|
"{\n"
|
|
".reg .pred p;\n"
|
|
"setp.ne.b32 p, %133, 0;\n"
|
|
"wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 "
|
|
"{%0, %1, %2, %3, %4, %5, %6, %7, "
|
|
" %8, %9, %10, %11, %12, %13, %14, %15, "
|
|
" %16, %17, %18, %19, %20, %21, %22, %23, "
|
|
" %24, %25, %26, %27, %28, %29, %30, %31, "
|
|
" %32, %33, %34, %35, %36, %37, %38, %39, "
|
|
" %40, %41, %42, %43, %44, %45, %46, %47, "
|
|
" %48, %49, %50, %51, %52, %53, %54, %55, "
|
|
" %56, %57, %58, %59, %60, %61, %62, %63, "
|
|
" %64, %65, %66, %67, %68, %69, %70, %71, "
|
|
" %72, %73, %74, %75, %76, %77, %78, %79, "
|
|
" %80, %81, %82, %83, %84, %85, %86, %87, "
|
|
" %88, %89, %90, %91, %92, %93, %94, %95, "
|
|
" %96, %97, %98, %99, %100, %101, %102, %103, "
|
|
" %104, %105, %106, %107, %108, %109, %110, %111, "
|
|
" %112, %113, %114, %115, %116, %117, %118, %119, "
|
|
" %120, %121, %122, %123, %124, %125, %126, %127},"
|
|
"{%128, %129, %130, %131},"
|
|
" %132,"
|
|
" p, %134, %135;\n"
|
|
"}\n"
|
|
: "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003),
|
|
"+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007),
|
|
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011),
|
|
"+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015),
|
|
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019),
|
|
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023),
|
|
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027),
|
|
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031),
|
|
"+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035),
|
|
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039),
|
|
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043),
|
|
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047),
|
|
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051),
|
|
"+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055),
|
|
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059),
|
|
"+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063),
|
|
"+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067),
|
|
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071),
|
|
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075),
|
|
"+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079),
|
|
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083),
|
|
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087),
|
|
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091),
|
|
"+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095),
|
|
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099),
|
|
"+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103),
|
|
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107),
|
|
"+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111),
|
|
"+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115),
|
|
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119),
|
|
"+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123),
|
|
"+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127)
|
|
: "r"(a000), "r"(a001), "r"(a002), "r"(a003),
|
|
"l"(desc_b),
|
|
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)));
|
|
#else
|
|
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90_64x256x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED");
|
|
#endif
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace cute
|