// cutlass/include/cute/algorithm/cooperative_gemm.hpp
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/util/type_traits.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/algorithm/axpby.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/tensor.hpp>
namespace cute
{
//
// Cooperative Shared-Memory GEMMs
//
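//
// Every thread participating in the TiledMMA cooperates on a single tile-level
// computation, C = alpha * A * B + beta * C, where sA is an (M,K), sB an (N,K),
// and sC an (M,N) shared-memory tensor. No synchronization is performed here:
// callers must ensure the shared-memory tiles are visible to all participating
// threads before the call and synchronize again before reusing the tiles.
//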
namespace detail {
// Predicated Cooperative GEMM
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
class CLoadTransformOp, class CStoreTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
{
using TypeA = typename TA::value_type;
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
// Original, static size of the problem
auto M = size<0>(sC);
auto N = size<1>(sC);
auto K = size<1>(sA);
// Block size of the compute tile
auto BLK_M = tile_size<0>(thr_mma);
auto BLK_N = tile_size<1>(thr_mma);
auto BLK_K = tile_size<2>(thr_mma);
//
// MMA Partitioning
//
// Round the layout extents up to the MMA tile sizes (BLK_M, BLK_N, BLK_K) so that the MMA partitioning below is always well-defined
Tensor rounded_sA = sA.compose(make_shape(round_up(M, BLK_M), round_up(K, BLK_K)));
Tensor rounded_sB = sB.compose(make_shape(round_up(N, BLK_N), round_up(K, BLK_K)));
Tensor rounded_sC = sC.compose(make_shape(round_up(M, BLK_M), round_up(N, BLK_N)));
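// For example (illustrative sizes), a (60,16) sA with BLK_M = 64 and BLK_K = 16
// is viewed as a (64,16) tensor here; the padded coordinates are never actually
// read thanks to the predication set up below.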
// Partition the sA and sB tiles across the threads for the MMA
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
#if 0
if (thread0()) {
print(" sA: "); print( sA); print("\n");
print(" sB: "); print( sB); print("\n");
print(" sC: "); print( sC); print("\n");
print("r_sA: "); print(rounded_sA); print("\n");
print("r_sB: "); print(rounded_sB); print("\n");
print("r_sC: "); print(rounded_sC); print("\n");
print(thr_mma);
print("tCsA: "); print(tCsA); print("\n");
print("tCsB: "); print(tCsB); print("\n");
print("tCsC: "); print(tCsC); print("\n");
print("tCrA: "); print(tCrA); print("\n");
print("tCrB: "); print(tCrB); print("\n");
print("tCrC: "); print(tCrC); print("\n");
}
#endif
//
// PREDICATION
//
// Create coordinate tensors for the problem
Tensor cA = make_identity_tensor(shape(rounded_sA)); // (M,K) -> (m,k)
Tensor cB = make_identity_tensor(shape(rounded_sB)); // (N,K) -> (n,k)
// Repeat partitioning with thr_mma
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k)
Tensor tCcB = thr_mma.partition_B(cB); // (MMA,MMA_N,MMA_K) -> (n,k)
// Allocate the preds for MMA- and MMA_MN-modes
Tensor tCpA = make_tensor<bool>(make_shape(size<0>(tCsA), size<1>(tCsA)));
Tensor tCpB = make_tensor<bool>(make_shape(size<0>(tCsB), size<1>(tCsB)));
// Populate the predicates on M and N
CUTE_UNROLL
for (int i = 0; i < size(tCpA); ++i) {
tCpA(i) = elem_less(get<0>(tCcA(_,_,Int<0>{})(i)), shape<0>(sA));
}
CUTE_UNROLL
for (int i = 0; i < size(tCpB); ++i) {
tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB));
}
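// The M/N predicates above depend only on the MMA- and MMA_MN-modes, so they are
// computed once per thread and reused for every k_block. Bounds in K are checked
// separately, and only for the last k_block, inside the load loops below.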
#if 0
if (thread0()) {
print(" cA: "); print( cA); print("\n");
print(" cB: "); print( cB); print("\n");
print("tCcA: "); print(tCcA); print("\n");
print("tCcB: "); print(tCcB); print("\n");
print_tensor(tCpA);
print_tensor(tCpB);
}
#endif
//
// PREFETCH k_block = 0
// The k-predicate only needs to be evaluated when loading the last k_block
// (k_block == K_BLOCK_MAX-1); all earlier k_blocks are fully in-bounds in K.
// Assumes the MMA-tiling in K is trivial.
//
constexpr int K_BLOCK_MAX = size<2>(tCrA);
CUTE_UNROLL
for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M
CUTE_UNROLL
for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I
tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
}
}
CUTE_UNROLL
for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N
CUTE_UNROLL
for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I
tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
}
}
//
// MAINLOOP
//
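// The k_block = 0 loads above and the k_next loads inside the loop form a
// single-stage register pipeline: while the MMA consumes k_block, the operands
// for k_block + 1 are already being fetched from shared memory.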
// Clear accumulators
clear(tCrC);
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
{
if (k_block < K_BLOCK_MAX-1) // static-if not the last k_block
{
int k_next = k_block + 1; // Load k_next block
// The k-predicate only needs to be evaluated when loading the last k_block
// (k_next == K_BLOCK_MAX-1); all earlier k_blocks are fully in-bounds in K.
// Assumes the MMA-tiling in K is trivial.
CUTE_UNROLL
for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M
CUTE_UNROLL
for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I
tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
}
}
CUTE_UNROLL
for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N
CUTE_UNROLL
for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I
tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
}
}
}
// GEMM on k_block in registers
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
}
//
// Epilogue
//
// Create coordinate tensors for the problem
Tensor cC = make_identity_tensor(shape(rounded_sC)); // (M,N) -> (m,n)
// Repeat partitioning with thr_mma
Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n)
const bool isBetaZero = (beta == Beta{});
// Custom axpby_if for now
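// Each thread writes back only the accumulators whose (m,n) coordinates fall
// inside the true extent of sC; elements that exist only in the rounded-up
// padding are skipped.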
CUTE_UNROLL
for (int i = 0; i < size(tCrC); ++i)
{
if (elem_less(tCcC(i), shape(sC)))
{
tCsC(i) = sC_store_op(isBetaZero ? alpha * static_cast<TypeC>(tCrC(i))
: alpha * static_cast<TypeC>(tCrC(i)) +
beta * static_cast<TypeC>(sC_load_op(tCsC(i))));
}
}
}
// Slow fallback path
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
class CLoadTransformOp, class CStoreTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm_predication(uint32_t thread_idx,
TiledMMA<Args...> const& tiled_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
{
// ThrMMA
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
cooperative_gemm_predication(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op);
}
// Unpredicated Cooperative GEMM
template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
class CLoadTransformOp, class CStoreTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm_no_predication(uint32_t thread_idx,
TiledMMA<Args...> const& tiled_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
{
using TypeA = typename TA::value_type;
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
// ThrMMA
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
//
// MMA Partitioning
//
Tensor tCsC = thr_mma.partition_C(sC);
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
using CopyOpAType = SmemCopyOpA;
using CopyOpBType = SmemCopyOpB;
auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, TypeA>{}, thr_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
Tensor tCsA = smem_thr_copy_A.partition_S(sA);
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K
auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, TypeB>{}, thr_mma);
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
Tensor tCsB = smem_thr_copy_B.partition_S(sB);
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K
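// tCrA/tCrB are the register fragments the MMA will consume; the *_copy_view
// retilings expose the same registers in the layout expected by the smem copy
// atoms, and the asserts above check that the two partitionings agree in M/N and K.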
#if 0
if (thread0()) {
print(" sA: "); print(sA); print("\n");
print(" sB: "); print(sB); print("\n");
print(" sC: "); print(sC); print("\n");
print(thr_mma); print("\n");
print("tCsC: "); print(tCsC); print("\n");
print("tCrA: "); print(tCrA); print("\n");
print("tCrB: "); print(tCrB); print("\n");
print("tCrC: "); print(tCrC); print("\n");
print(smem_thr_copy_A); print("\n");
print("tCsA: "); print(tCsA); print("\n");
print("tCrA_copy_view: "); print(tCrA_copy_view); print("\n");
print(smem_thr_copy_B); print("\n");
print("tCsB: "); print(tCsB); print("\n");
print("tCrB_copy_view: "); print(tCrB_copy_view); print("\n");
}
#endif
//
// PREFETCH
//
copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{}));
copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
//
// MAINLOOP
//
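// Same single-stage register pipeline as the predicated path, except that the
// operand loads go through the user-provided tiled smem copy atoms, so they can
// vectorize and, depending on the chosen copy op, use specialized instructions.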
// Clear accumulators
clear(tCrC);
constexpr int K_BLOCK_MAX = size<2>(tCrA);
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
{
// static-if load the next k_block. No k-predication required on these loads.
if (k_block < K_BLOCK_MAX-1)
{
// Load the next k_block
int k_next = k_block + 1; // statically unrolled
copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrA_copy_view(_,_,k_next));
copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrB_copy_view(_,_,k_next));
}
// Transform A and B in registers; the compiler elides these calls when the transform is an identity op
cute::transform(tCrA(_,_,k_block), sA_load_op);
cute::transform(tCrB(_,_,k_block), sB_load_op);
// GEMM on k_block in registers
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
}
//
// Epilogue
//
auto isBetaZero = [&] () {
if constexpr (is_complex<Beta>::value) {
return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
}
else {
return beta == Int<0>{};
}
CUTE_GCC_UNREACHABLE;
} ();
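// When beta is zero, C is never read: the copy and load-transform below are
// skipped entirely, so sC may hold uninitialized data without affecting the result.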
using CopyOpCType = SmemCopyOpC;
Tensor tCrD = thr_mma.make_fragment_C(tCsC);
if(!isBetaZero) {
copy(CopyOpCType{}, tCsC, tCrD);
// Transform C on/after load
cute::transform(tCrD, sC_load_op);
}
// C = alpha * (A * B) + beta * C
axpby(alpha, tCrC, beta, tCrD);
// Transform C before/on store
cute::transform(tCrD, sC_store_op);
copy(CopyOpCType{}, tCrD, tCsC);
}
} // end namespace detail
template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm(uint32_t thread_idx,
TiledMMA<Args...> const& tiled_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
{
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
using TypeA = typename TA::value_type;
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
static_assert(is_convertible_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
"ALoadTransformOp functor must accept value of type TA::value_type and return value convertible to type TA::value_type");
static_assert(is_convertible_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
"BLoadTransformOp functor must accept value of type TB::value_type and return value convertible to type TB::value_type");
static_assert(is_convertible_v<decay_t<invoke_result_t<CLoadTransformOp, TypeC>>, TypeC>,
"CLoadTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
static_assert(is_convertible_v<decay_t<invoke_result_t<CStoreTransformOp, TypeC>>, TypeC>,
"CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
static constexpr bool compat = weakly_compatible(tile_shape(TiledMMA<Args...>{}),
make_shape(size<0>(sA), size<0>(sB), size<1>(sA)));
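// Dispatch: if the (M,N,K) problem shape is weakly compatible with the MMA tile
// shape (no partial tiles), take the unpredicated path with tiled smem copies;
// otherwise fall back to the fully predicated, element-wise path.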
if constexpr (compat) {
detail::cooperative_gemm_no_predication<SmemCopyOpA, SmemCopyOpB, SmemCopyOpC>(
thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
sA_load_op, sB_load_op, sC_load_op, sC_store_op
);
} else {
detail::cooperative_gemm_predication(
thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
sA_load_op, sB_load_op, sC_load_op, sC_store_op
);
}
}
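// A minimal usage sketch (illustration only, guarded out of compilation): a CTA
// holds one (M,K) x (N,K) -> (M,N) tile entirely in shared memory and every
// thread of the TiledMMA cooperates on it. The kernel name, tile sizes, element
// types, the 1-D thread block whose size matches size(TiledMma{}), and the
// global-memory staging are assumptions made for the example, not requirements
// of this header.
#if 0
template <class TiledMma>
__global__ void example_cooperative_gemm_kernel(half_t const* gA, half_t const* gB, float* gC)
{
  __shared__ half_t smemA[64 * 16];
  __shared__ half_t smemB[64 * 16];
  __shared__ float  smemC[64 * 64];
  // Compact, column-major shared-memory tiles in the shapes cooperative_gemm expects
  Tensor sA = make_tensor(make_smem_ptr(smemA), make_shape(Int<64>{}, Int<16>{})); // (M,K)
  Tensor sB = make_tensor(make_smem_ptr(smemB), make_shape(Int<64>{}, Int<16>{})); // (N,K)
  Tensor sC = make_tensor(make_smem_ptr(smemC), make_shape(Int<64>{}, Int<64>{})); // (M,N)
  // ... stage gA/gB/gC into sA/sB/sC and __syncthreads() ...
  TiledMma tiled_mma;
  cooperative_gemm(threadIdx.x, tiled_mma,
                   1.0f, sA, sB,   // alpha, A (M,K), B (N,K)
                   0.0f, sC);      // beta,  C (M,N)
  __syncthreads();
  // ... write sC back to gC ...
}
#endif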
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm(uint32_t thread_idx,
TiledMMA<Args...> const& tiled_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
{
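// Default to auto-vectorizing copies that assume only element-level alignment of
// the shared-memory tensors; callers with stronger alignment guarantees can use
// the overload above and pass wider copy ops explicitly.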
using CopyOpA = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TA::value_type>>;
using CopyOpB = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TB::value_type>>;
using CopyOpC = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TC::value_type>>;
cooperative_gemm<CopyOpA, CopyOpB, CopyOpC>(
thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
sA_load_op, sB_load_op, sC_load_op, sC_store_op
);
}
// Legacy overload of cute::gemm for backward compatibility
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
{
// Dispatches directly to the slow (predicated) path to avoid having to recover thread_idx from thr_mma
detail::cooperative_gemm_predication(
thr_mma, alpha, sA, sB, beta, sC,
sA_load_op, sB_load_op, sC_load_op, sC_store_op
);
}
} // end namespace cute