cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h

632 lines
22 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
File copied from "cutlass/epilogue/threadblock/epilogue.h"
then modified to:
(1) load 2 source fragments at the same time (pipelining)
(2) support reading from a different dtype
(3) pass the row id to the OutputOp if it takes it
(see MemoryEfficientAttentionNormalize)
Note that in general the fragment passed to the OutputOp could
span multiple rows but it does not happen with the configurations we have
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <typename Op>
struct ApplyEpilogueOp {
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum,
typename Op::FragmentOutput const& source) {
return output_op(accum, source);
}
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum) {
return output_op(accum);
}
};
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator
template <
typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept:
///< gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator writing output tensors
typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting
///< accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing
///< accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading
///< from SMEM
typename OutputOp_, ///< Output operator
typename Padding_, ///< Padding added to SMEM allocation to avoid bank
///< conflicts (concept: MatrixShape)
int FragmentsPerPartition =
1, ///< Used to coarsten the epilogue granularity
int IterationsUnroll = ///< Used to reduce binary size when epilogue op is
///< large
(!IsEpilogueFunctorHeavy<OutputOp_>::value),
typename OutputTileSourceIterator_ =
OutputTileIterator_ ///< Tile iterator reading tensors
>
class EpiloguePipelined : public EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition> {
public:
using Base = EpilogueBase<
Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using OutputTileSourceIterator = OutputTileSourceIterator_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename Base::AccumulatorTile;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
using ElementSource = typename OutputTileSourceIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef =
typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
using OutputAccessType = Array<
typename OutputTileIterator::Element,
OutputTileIterator::kElementsPerAccess>;
using SourceAccessType = Array<
typename OutputTileSourceIterator::Element,
OutputTileSourceIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType = Array<
typename WarpTileIterator::Element,
OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount = typename Base::WarpCount;
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1
? Base::kFragmentsPerIteration
: kPartitionsK;
static int constexpr kSmemPointerOffset =
Base::SharedStorage::StorageShape::kCount / kSmemTiles;
public:
static_assert(
OutputTileSourceIterator::Fragment::kElements ==
OutputTileIterator::Fragment::kElements,
"Mismatch between input tile and output tile iterator (kElements)");
static_assert(
OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations,
"Mismatch between input tile and output tile iterator (kIterations)");
static_assert(
SharedLoadIterator::Fragment::kElements ==
OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
static_assert(
OutputTileIterator::kElementsPerAccess,
"OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(
!(OutputTileIterator::Fragment::kElements %
OutputTileIterator::kElementsPerAccess),
"Divisibility");
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
public:
/// Constructor
CUTLASS_DEVICE
EpiloguePipelined(
typename Base::SharedStorage& shared_storage, ///< Shared storage object
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
shared_load_iterator_(shared_storage.reference(), thread_idx) {}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators, ///< Complete warp-level accumulator tile
OutputTileSourceIterator
source_iterator) { ///< Threadblock tile coordinate in GEMM (in units
///< of threadblock tiles)
if (!output_op.is_source_needed()) {
compute_source_not_needed_(output_op, destination_iterator, accumulators);
} else {
compute_source_needed_(
output_op, destination_iterator, accumulators, source_iterator);
}
}
CUTLASS_DEVICE
void operator()(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators) { ///< Complete warp-level accumulator tile
compute_source_not_needed_(output_op, destination_iterator, accumulators);
}
private:
template <class Seq>
struct acc2smem_source_not_needed;
template <size_t... Seq>
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(
AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator& warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
warp_tile_iterator.store(accum_fragment);
if (p < Base::kFragmentsPerIteration - 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
}
}
if (Base::kFragmentsPerIteration > 1) {
warp_tile_iterator.add_pointer_offset(
kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
}
CUTLASS_DEVICE
static void push(
size_t pos,
AccumulatorFragmentIterator const& iterator_begin,
WarpTileIterator& warp_tile_iterator) {
int dummy[] = {
(pos == (Seq * Base::kFragmentsPerIteration)) &&
(helper<Seq * Base::kFragmentsPerIteration>(
iterator_begin, warp_tile_iterator),
0)...};
CUTLASS_UNUSED(dummy[0]);
}
};
static_assert(
kPartitionsK == 1 || Base::kFragmentsPerIteration == 1,
"One of these must be exactly 1.");
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators ///< Complete warp-level accumulator tile
) {
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll( \
IterationsUnroll \
? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \
: 1)
for (int iter = 0; iter < OutputTileIterator::kIterations;
iter += Base::kFragmentsPerIteration) {
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_not_needed<cutlass::make_index_sequence<
OutputTileIterator::kIterations / Base::kFragmentsPerIteration>>::
push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename SharedLoadIterator::Fragment
aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
if (p < Base::kFragmentsPerIteration - 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
} else if (kPartitionsK > 1) {
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(
aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset(
(1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_source_not_needed_(
destination_iterator.thread_start_row(),
output_fragment,
output_op,
aligned_accum_fragment[0]);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
if (Base::kFragmentsPerIteration > 1) {
shared_load_iterator_.add_pointer_offset(
kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
}
}
template <class Seq>
struct acc2smem_source_needed;
template <size_t... Seq>
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(
AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator& warp_tile_iterator) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
warp_tile_iterator.store(accum_fragment);
}
CUTLASS_DEVICE
static void push(
size_t pos,
AccumulatorFragmentIterator const& iterator_begin,
WarpTileIterator& warp_tile_iterator) {
int dummy[] = {
(pos == Seq) &&
(helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators, ///< Complete warp-level accumulator tile
OutputTileSourceIterator
source_iterator ///< Threadblock tile coordinate in GEMM (in units of
///< threadblock tiles)
) {
typename OutputTileSourceIterator::Fragment source_fragment[2];
source_fragment[0].clear();
source_iterator.load(source_fragment[0]);
++source_iterator;
source_fragment[1].clear();
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
if (iter > 0) {
__syncthreads();
}
//
// Load the source for next iteration (pipelining)
//
if (iter + 1 < OutputTileIterator::kIterations) {
source_iterator.load(source_fragment[(iter + 1) % 2]);
}
++source_iterator;
acc2smem_source_needed<
cutlass::make_index_sequence<OutputTileIterator::kIterations>>::
push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment
aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// If the number of k-slices is > 1 - perform a reduction amongst the
// k-slices
if (kPartitionsK > 1) {
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(
aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset(
(1 - kPartitionsK) * kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_(
destination_iterator.thread_start_row(),
output_fragment,
output_op,
aligned_accum_fragment[0],
source_fragment[iter % 2]);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
int begin_row,
typename OutputTileIterator::Fragment& output_fragment,
OutputOp const& output_op, ///< Output operator
typename SharedLoadIterator::Fragment const& aligned_accum_fragment,
typename OutputTileSourceIterator::Fragment const& source_fragment) {
OutputAccessType* output_frag_ptr =
reinterpret_cast<OutputAccessType*>(&output_fragment);
AccumulatorAccessType const* compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);
SourceAccessType const* source_frag_ptr =
reinterpret_cast<SourceAccessType const*>(&source_fragment);
int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
output_op,
begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
compute_frag_ptr[i],
source_frag_ptr[i]);
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
int begin_row,
typename OutputTileIterator::Fragment& output_fragment,
OutputOp const& output_op, ///< Output operator
typename SharedLoadIterator::Fragment const& aligned_accum_fragment) {
OutputAccessType* output_frag_ptr =
reinterpret_cast<OutputAccessType*>(&output_fragment);
AccumulatorAccessType const* compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);
int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
output_op,
begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
compute_frag_ptr[i]);
}
}
// This should be constexpr, but it's only supported on c++14
static int CUTLASS_HOST_DEVICE getRowOffset(int i) {
using ThreadMap = typename OutputTileIterator::ThreadMap;
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
int frag_row_idx =
(row +
ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
int frag_idx = ThreadMap::kElementsPerAccess *
(frag_row_idx * ThreadMap::Iterations::kColumn + column);
if (i < frag_idx + ThreadMap::kElementsPerAccess) {
return row_offset;
}
}
}
}
}
return -1;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////