/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Epilogue that applies an elementwise operation to the accumulators
           and scatter-writes the result to global memory through a
           user-provided index mapping.
*/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/detail.hpp"

#include "cute/tensor.hpp"
#include "cute/numeric/numeric_types.hpp"

#include "gather_tensor.hpp"

namespace cutlass::epilogue::collective {

/// Applies an elementwise operation to all elements of the accumulator fragment
/// and scatter-writes them out to destination storage.
/// GatherC and ScatterD are types of user-defined functions that apply the
/// transformation of the strided coordinate (e.g. through an index array).
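///
/// A minimal sketch of such a functor (assumed to follow the shape of the
/// IndexedGather helper in gather_tensor.hpp; the name below is hypothetical):
///
///   struct ExampleIndexedGather {
///     int const* indices = nullptr;  // device-resident index array
///     // Remap a logical coordinate through the index array
///     CUTLASS_HOST_DEVICE int operator()(int i) const { return indices[i]; }
///   };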
template <
  class StrideC_,
  class StrideD_,
  class ThreadEpilogueOp_,
  class EpilogueSchedule_,
  class GatherC_,
  class ScatterD_
>
class EpilogueGatherScatter {
public:
  //
  // Type Aliases
  //
  using EpilogueSchedule = EpilogueSchedule_;

  // Derived types of the output thread-level operator
  using ThreadEpilogueOp = ThreadEpilogueOp_;
  using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
  using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
  using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
  using ElementScalar = ElementCompute;
  using ElementC = typename ThreadEpilogueOp::ElementC;
  using StrideC = StrideC_;
  using ElementD = typename ThreadEpilogueOp::ElementD;
  using StrideD = StrideD_;

  // Every epilogue needs these two GmemTiledCopy{C,D} aliases.
  // If you don't know what they should be, just use void.
  using GmemTiledCopyC = void;
  using GmemTiledCopyD = void;

  using GatherC = GatherC_;
  using ScatterD = ScatterD_;

  static const int kOutputAlignment = ThreadEpilogueOp::kCount;
  using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;

  static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
  static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");

  struct SharedStorage { };

  // Host-side epilogue arguments
  struct Arguments {
    typename ThreadEpilogueOp::Params thread_params{};
    ElementC const* ptr_C = nullptr;
    StrideC dC{};
    ElementD* ptr_D = nullptr;
    StrideD dD{};
    GatherC gather_C{};
    ScatterD scatter_D{};
  };
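
  // Example host-side setup (illustrative sketch; `idx_C`/`idx_D` are
  // hypothetical device-resident index arrays, and GatherC/ScatterD are
  // assumed to be index-based functors such as IndexedGather from
  // gather_tensor.hpp):
  //
  //   Arguments args{
  //     {alpha, beta},    // thread-level epilogue parameters
  //     ptr_C, dC,        // gather-read source and its stride
  //     ptr_D, dD,        // scatter-write destination and its stride
  //     GatherC{idx_C},   // indices to read C through
  //     ScatterD{idx_D}   // indices to write D through
  //   };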
  // Device-side epilogue params
  using Params = Arguments;

  //
  // Methods
  //

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(
      [[maybe_unused]] ProblemShape const& _,
      Arguments const& args,
      [[maybe_unused]] void* workspace) {
    // Arguments are usable on device as-is: no workspace or repacking is needed
    return args;
  }

  template <class ProblemShape>
  CUTLASS_HOST_DEVICE static bool
  can_implement(
      [[maybe_unused]] ProblemShape const& problem_shape,
      [[maybe_unused]] Arguments const& args) {
    // This example accepts all inputs; a production epilogue would validate
    // alignment and problem shape here
    return true;
  }

  CUTLASS_HOST_DEVICE
  EpilogueGatherScatter(Params const& params_) : params(params_) { }
  template <
    class ProblemShapeMNKL,
    class BlockShapeMNK,
    class BlockCoordMNKL,
    class FrgEngine, class FrgLayout,
    class TiledMma,
    class ResidueMNK
  >
  CUTLASS_DEVICE void
  operator()(
      ProblemShapeMNKL problem_shape_mnkl,
      BlockShapeMNK blk_shape_MNK,
      BlockCoordMNKL blk_coord_mnkl,
      cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
      TiledMma tiled_mma,
      ResidueMNK residue_mnk,
      int thread_idx,
      char* smem_buf)
  {
    using namespace cute;
    using X = Underscore;

    static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
    static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
    static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
    static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");

    // This epilogue does not use shared memory
    (void) smem_buf;

    ThreadEpilogueOp epilogue_op{params.thread_params};

    // Separate out problem shape for convenience
    auto M = get<0>(problem_shape_mnkl);
    auto N = get<1>(problem_shape_mnkl);
    auto L = get<3>(problem_shape_mnkl);

    auto stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC);
    auto stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD);

    // Represent the full output tensors, with the gather/scatter functors
    // folded into their strides so that plain indexing performs the indirection
    Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C);  // (m,n,l)
    Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D); // (m,n,l)
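
    // Illustrative mapping (assuming an index-based ScatterD over the M mode):
    // element (m,n,l) of mD_mnl resolves to the linear offset
    //   scatter_D(m) * stride_M + n * stride_N + l * stride_L
    // so the tiling and partitioning below can remain indirection-agnostic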
    // Tile the output tensors by the CTA tile shape (the K mode is ignored)
    Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
    Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

    // Slice to get the tile this CTA is responsible for
    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
    Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
    Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)

    // Partition source and destination tiles to match the accumulator partitioning
    auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
    Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N)
    Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N)

    static_assert(is_static<FrgLayout>::value, "Accumulator layout must be static");
    CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD),
        "Source and destination must have the same number of elements.");
    CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators),
        "Accumulator count must match the destination element count.");

    // Make an identity coordinate tensor for predicating our output MN tile
    auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
    Tensor tCcD = thr_mma.partition_C(cD);
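
    // Worked example of the predication below (assumed numbers): with a
    // 128x128 CTA tile and residue (m,n) = (100, 128), coordinates with
    // row >= 100 lie outside the problem, so elem_less() masks off their
    // loads and stores in the loops that follow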
    if (epilogue_op.is_source_needed()) {
      // Source is needed: gather-load C, apply the epilogue, scatter-store D
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(accumulators); ++i) {
        if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
          tCgD(i) = epilogue_op(accumulators(i), tCgC(i));
        }
      }
    }
    else {
      // Source is not needed: avoid the gather-load of C entirely
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(accumulators); ++i) {
        if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
          tCgD(i) = epilogue_op(accumulators(i));
        }
      }
    }
  }

private:
  Params params;
};
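
// Example instantiation (illustrative sketch; EpilogueOp, the stride types,
// and the gather functors are assumptions drawn from the surrounding example):
//
//   using CollectiveEpilogue = cutlass::epilogue::collective::EpilogueGatherScatter<
//     StrideC, StrideD,
//     EpilogueOp,                      // e.g. a thread-level LinearCombination
//     cutlass::gemm::EpilogueDefault,  // schedule tag consumed by get_epilogue_stride
//     GatherC,                         // e.g. IndexedGather<int> from gather_tensor.hpp
//     ScatterD                         // e.g. IndexedGather<int> over the output rows
//   >;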
} // namespace cutlass::epilogue::collective