223 lines
8.3 KiB
C++
223 lines
8.3 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
/*! \file
|
|
\brief Functor performing elementwise operations used by epilogues.
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/gemm/dispatch_policy.hpp"
|
|
#include "cutlass/epilogue/collective/detail.hpp"
|
|
|
|
#include "cute/tensor.hpp"
|
|
#include "cute/numeric/numeric_types.hpp"
|
|
|
|
#include "gather_tensor.hpp"
|
|
|
|
namespace cutlass::epilogue::collective {
|
|
|
|
/// Applies an element wise operation to all elements within the fragment
|
|
/// and scatter-writes them out to destination storage.
|
|
/// GatherC and ScatterD are types of user-defined functions that apply the
|
|
/// transformation of the strided coordinate (e.g. through an index array).
|
|
template <
|
|
class StrideC_,
|
|
class StrideD_,
|
|
class ThreadEpilogueOp_,
|
|
class EpilogueSchedule_,
|
|
class GatherC_,
|
|
class ScatterD_
|
|
>
|
|
class EpilogueGatherScatter {
|
|
public:
|
|
//
|
|
// Type Aliases
|
|
//
|
|
using EpilogueSchedule = EpilogueSchedule_;
|
|
|
|
// derived types of output thread level operator
|
|
using ThreadEpilogueOp = ThreadEpilogueOp_;
|
|
using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
|
|
using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
|
|
using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
|
|
using ElementScalar = ElementCompute;
|
|
using ElementC = typename ThreadEpilogueOp::ElementC;
|
|
using StrideC = StrideC_;
|
|
using ElementD = typename ThreadEpilogueOp::ElementD;
|
|
using StrideD = StrideD_;
|
|
|
|
// Every epilogue needs these two GmemTiledCopy{C,D} aliases.
|
|
// If you don't know what they should be, just use void.
|
|
using GmemTiledCopyC = void;
|
|
using GmemTiledCopyD = void;
|
|
|
|
using GatherC = GatherC_;
|
|
using ScatterD = ScatterD_;
|
|
|
|
static const int kOutputAlignment = ThreadEpilogueOp::kCount;
|
|
using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;
|
|
|
|
static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
|
static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");
|
|
|
|
struct SharedStorage { };
|
|
|
|
// Host side epilogue arguments
|
|
struct Arguments {
|
|
typename ThreadEpilogueOp::Params thread_params{};
|
|
ElementC const* ptr_C = nullptr;
|
|
StrideC dC{};
|
|
ElementD* ptr_D = nullptr;
|
|
StrideD dD{};
|
|
GatherC gather_C{};
|
|
ScatterD scatter_D{};
|
|
};
|
|
|
|
// Device side epilogue params
|
|
using Params = Arguments;
|
|
|
|
//
|
|
// Methods
|
|
//
|
|
|
|
template <class ProblemShape>
|
|
static constexpr Params
|
|
to_underlying_arguments(
|
|
[[maybe_unused]] ProblemShape const& _,
|
|
Arguments const& args,
|
|
[[maybe_unused]] void* workspace) {
|
|
return args;
|
|
}
|
|
|
|
template<class ProblemShape>
|
|
CUTLASS_HOST_DEVICE static bool
|
|
can_implement(
|
|
[[maybe_unused]] ProblemShape const& problem_shape,
|
|
[[maybe_unused]] Arguments const& args) {
|
|
return true;
|
|
}
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
EpilogueGatherScatter(Params const& params_) : params(params_) { }
|
|
|
|
template<
|
|
class ProblemShapeMNKL,
|
|
class BlockShapeMNK,
|
|
class BlockCoordMNKL,
|
|
class FrgEngine, class FrgLayout,
|
|
class TiledMma,
|
|
class ResidueMNK
|
|
>
|
|
CUTLASS_DEVICE void
|
|
operator()(
|
|
ProblemShapeMNKL problem_shape_mnkl,
|
|
BlockShapeMNK blk_shape_MNK,
|
|
BlockCoordMNKL blk_coord_mnkl,
|
|
cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
|
|
TiledMma tiled_mma,
|
|
ResidueMNK residue_mnk,
|
|
int thread_idx,
|
|
char* smem_buf)
|
|
{
|
|
using namespace cute;
|
|
using X = Underscore;
|
|
|
|
static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
|
|
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
|
|
static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
|
|
static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");
|
|
|
|
(void) smem_buf;
|
|
ThreadEpilogueOp epilogue_op{params.thread_params};
|
|
|
|
// Separate out problem shape for convenience
|
|
auto M = get<0>(problem_shape_mnkl);
|
|
auto N = get<1>(problem_shape_mnkl);
|
|
auto L = get<3>(problem_shape_mnkl);
|
|
|
|
auto stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC);
|
|
auto stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD);
|
|
|
|
// Represent the full output tensor
|
|
Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C); // (m,n,l)
|
|
Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D); // (m,n,l)
|
|
|
|
Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
|
Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
|
|
|
|
// Slice to get the tile this CTA is responsible for
|
|
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
|
|
Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
|
|
Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N)
|
|
|
|
// Partition source and destination tiles to match the accumulator partitioning
|
|
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
|
|
Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N)
|
|
Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N)
|
|
|
|
static_assert(is_static<FrgLayout>::value, "Accumulator layout must be static");
|
|
CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD),
|
|
"Source and destination must have the same number of elements.");
|
|
CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators),
|
|
"Accumulator count must have the same destination element count.");
|
|
|
|
// Make an identity coordinate tensor for predicating our output MN tile
|
|
auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
|
|
Tensor tCcD = thr_mma.partition_C(cD);
|
|
|
|
// source is needed
|
|
if (epilogue_op.is_source_needed()) {
|
|
CUTLASS_PRAGMA_UNROLL
|
|
for (int i = 0; i < size(accumulators); ++i) {
|
|
if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
|
|
tCgD(i) = epilogue_op(accumulators(i), tCgC(i));
|
|
}
|
|
}
|
|
}
|
|
// source is not needed, avoid load
|
|
else {
|
|
CUTLASS_PRAGMA_UNROLL
|
|
for (int i = 0; i < size(accumulators); ++i) {
|
|
if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
|
|
tCgD(i) = epilogue_op(accumulators(i));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
private:
|
|
Params params;
|
|
};
|
|
|
|
} // namespace cutlass::epilogue::collective
|
|
|