/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#ifdef HAS_PYTORCH
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#endif

#include <curand_kernel.h>
#include <cmath>
#include <vector>

#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "debug_utils.h"
#include "epilogue/epilogue_pipelined.h"
#include "epilogue/epilogue_rescale_output.h"
#include "gemm/custom_mma.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
#include "gemm_kernel_utils.h"
#include "transform/tile_smem_loader.h"

#include <inttypes.h>

using namespace gemm_kernel_utils;

namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSmFw() {
  return (
      Arch::kMinComputeCapability >= 80 &&
              !cutlass::platform::is_same<scalar_t, float>::value
          ? 16
          : 12);
}
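// Why the bit-level trick below works: for IEEE-754 floats of the same sign,
// the ordering of the bit patterns matches the numeric ordering when
// reinterpreted as integers. Non-negative values can therefore use a signed
// integer atomicMax; negative values use an unsigned atomicMin because their
// ordering is reversed under the unsigned reinterpretation.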
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
  // source: https://stackoverflow.com/a/51549250
  return (value >= 0)
      ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
      : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}
} // namespace

// If a ToBatchHookType_ other than this default is supplied (which is
// never the case in the xformers library itself), the user defines the
// logic each block uses to find the data it works on, via an
// advance_to_batch function with the following signature.
// It should return false if there is no work to do for this block.
// In general this will not work with saving for backward, due to the fixed
// layout for logsumexp and incompatible RNGs for dropout, so it is likely
// only useful for custom inference.
struct DefaultToBatchHook {
  template <typename Params>
  CUTLASS_DEVICE static bool advance_to_batch(
      Params&,
      int64_t& /* q_start */,
      int64_t& /* k_start */) {
    return true;
  }
};

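// A minimal sketch of what a custom hook could look like (illustrative only,
// not part of the library; `my_q_offsets` / `my_k_offsets` are hypothetical
// fields the user would add to their own Params type):
//
//   struct MyToBatchHook {
//     template <typename Params>
//     CUTLASS_DEVICE static bool advance_to_batch(
//         Params& p, int64_t& q_start, int64_t& k_start) {
//       q_start = p.my_q_offsets[blockIdx.z];
//       k_start = p.my_k_offsets[blockIdx.z];
//       return q_start >= 0; // negative offset => nothing to do here
//     }
//   };
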
template <
    // The datatype of Q/K/V
    typename scalar_t_,
    // Architecture we are targeting (eg `cutlass::arch::Sm80`)
    typename ArchTag,
    // If Q/K/V are correctly aligned in memory and we can run a fast kernel
    bool isAligned_,
    int kQueriesPerBlock_,
    int kKeysPerBlock_,
    // upperbound on `max(value.shape[-1], query.shape[-1])`
    int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
    // This is quite a bit slower on V100 for some reason.
    // Set to false if you know at compile-time you will never need dropout
    bool kSupportsDropout_ = true,
    bool kSupportsBias_ = true,
    typename ToBatchHookType_ = DefaultToBatchHook>
struct AttentionKernel {
  enum CustomMaskType {
    NoCustomMask = 0,
    CausalFromTopLeft = 1,
    CausalFromBottomRight = 2,
    NumCustomMaskTypes,
  };

  using scalar_t = scalar_t_;
  using accum_t = float;
  using lse_scalar_t = float;
  using output_t = scalar_t;
  // Accumulator between 2 iterations
  // Using `accum_t` improves perf on f16 at the cost of
  // numerical errors
  using output_accum_t = accum_t;
  static constexpr bool kSupportsDropout = kSupportsDropout_;
  static constexpr bool kSupportsBias = kSupportsBias_;
  static constexpr int kKeysPerBlock = kKeysPerBlock_;
  static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
  static constexpr int kMaxK = kMaxK_;
  static constexpr bool kIsAligned = isAligned_;
  static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
  static constexpr int32_t kAlignLSE = 32; // block size of backward
  static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
  static constexpr bool kPreloadV =
      ArchTag::kMinComputeCapability >= 80 && kIsHalf;
  static constexpr bool kKeepOutputInRF = kSingleValueIteration;
  static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
      !cutlass::platform::is_same<output_accum_t, output_t>::value;

  static_assert(kQueriesPerBlock % 32 == 0, "");
  static_assert(kKeysPerBlock % 32 == 0, "");
  static constexpr int kNumWarpsPerBlock =
      kQueriesPerBlock * kKeysPerBlock / (32 * 32);
  static constexpr int kWarpSize = 32;

  // Launch bounds
  static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
  static constexpr int kMinBlocksPerSm =
      getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;

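  // For example: with kQueriesPerBlock = 64 and kKeysPerBlock = 64,
  // kNumWarpsPerBlock = 64 * 64 / (32 * 32) = 4, so kNumThreads = 128 and
  // kMinBlocksPerSm = 16 / 4 = 4 for f16 on Sm80+ (12 / 4 = 3 otherwise).
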
  struct Params {
    // Input tensors
    scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
    scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
    scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
    scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
    int32_t* seqstart_q_ptr = nullptr;
    int32_t* seqstart_k_ptr = nullptr;

    int32_t* seqlen_k_ptr = nullptr;
    uint32_t causal_diagonal_offset = 0;

    // Output tensors
    output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
    // [num_queries, num_heads, head_dim_value]
    output_accum_t* output_accum_ptr = nullptr;
    // [num_heads, num_queries] - can be null
    lse_scalar_t* logsumexp_ptr = nullptr;

    // Scale
    accum_t scale = 0.0;

    // Dimensions/strides
    int32_t head_dim = 0;
    int32_t head_dim_value = 0;
    int32_t num_queries = 0;
    int32_t num_keys = 0;
    int32_t num_keys_absolute = 0;

    uint8_t custom_mask_type = NoCustomMask;

    int32_t q_strideM = 0;
    int32_t k_strideM = 0;
    int32_t v_strideM = 0;
    int32_t bias_strideM = 0;

    int32_t o_strideM = 0;

    // Everything below is only used in `advance_to_block`
    // and shouldn't use registers
    int32_t q_strideH = 0;
    int32_t k_strideH = 0;
    int32_t v_strideH = 0;
    int64_t bias_strideH = 0;

    int64_t q_strideB = 0;
    int64_t k_strideB = 0;
    int64_t v_strideB = 0;
    int64_t bias_strideB = 0;

    int32_t num_batches = 0;
    int32_t num_heads = 0;

    // dropout
    bool use_dropout = false;
    unsigned long long dropout_batch_head_rng_offset = 0;
    float dropout_prob = 0.0f;
#ifdef HAS_PYTORCH
    at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
#endif

    // Moves pointers to what we should process
    // Returns "false" if there is no work to do
    CUTLASS_DEVICE bool advance_to_block() {
      auto batch_id = blockIdx.z;
      auto head_id = blockIdx.y;
      auto query_start = blockIdx.x * kQueriesPerBlock;

      auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;

      if (kSupportsDropout) {
        dropout_batch_head_rng_offset =
            batch_id * num_heads * num_queries * num_keys +
            head_id * num_queries * num_keys;
      }

      int64_t q_start = 0, k_start = 0;
      // Advance to current batch - in case of different sequence lengths
      constexpr bool kToBatchHook =
          !cutlass::platform::is_same<ToBatchHookType_, DefaultToBatchHook>::
              value;
      if (kToBatchHook) {
        // Call out to a custom implementation.
        if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) {
          return false;
        }
      } else if (seqstart_q_ptr != nullptr) {
        assert(seqstart_k_ptr != nullptr);
        seqstart_q_ptr += batch_id;

        q_start = seqstart_q_ptr[0];
        int64_t q_next_start = seqstart_q_ptr[1];
        int64_t k_end;
        seqstart_k_ptr += batch_id;

        if (seqlen_k_ptr) {
          k_start = seqstart_k_ptr[0];
          k_end = k_start + seqlen_k_ptr[batch_id];
        } else {
          k_start = seqstart_k_ptr[0];
          k_end = seqstart_k_ptr[1];
        }

        num_queries = q_next_start - q_start;
        num_keys = k_end - k_start;

        if (query_start >= num_queries) {
          return false;
        }
      } else {
        query_ptr += batch_id * q_strideB;
        key_ptr += batch_id * k_strideB;
        value_ptr += batch_id * v_strideB;
        output_ptr += int64_t(batch_id * num_queries) * o_strideM;
        if (output_accum_ptr != nullptr) {
          output_accum_ptr +=
              int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
        }
        q_start = 0;
        k_start = 0;
      }

      // Advance to the current batch / head / query_start
      query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
      key_ptr += k_start * k_strideM + head_id * k_strideH;

      value_ptr += k_start * v_strideM + head_id * v_strideH;
      output_ptr +=
          int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;

      if (kSupportsBias && attn_bias_ptr != nullptr) {
        attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
      }
      if (output_accum_ptr != nullptr) {
        output_accum_ptr +=
            int64_t(q_start + query_start) * (head_dim_value * num_heads) +
            head_id * head_dim_value;
      } else {
        // Accumulate directly in the destination buffer (eg for f32)
        output_accum_ptr = (accum_t*)output_ptr;
      }

      if (logsumexp_ptr != nullptr) {
        // lse[batch_id, head_id, query_start]
        logsumexp_ptr +=
            batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
      }

      // Custom masking
      if (custom_mask_type == CausalFromBottomRight) {
        causal_diagonal_offset = num_keys - num_queries;
      }
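      // Illustrative example: with num_queries = 4 and num_keys = 10, the
      // offset is 6, so query row `i` may attend to keys [0, i + 6] and the
      // causal diagonal ends in the bottom-right corner of the attention
      // matrix.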
      // We use num_keys_absolute to index into the rng_state
      // We need this index to match between forward and backward
      num_keys_absolute = num_keys;
      if (custom_mask_type == CausalFromTopLeft ||
          custom_mask_type == CausalFromBottomRight) {
        // the bottom row of the current block is query_start + kQueriesPerBlock
        // the last active key is then query_start + causal_diagonal_offset +
        // kQueriesPerBlock, so num_keys is the min between the actual num_keys
        // and this, to avoid extra computations
        num_keys = cutlass::fast_min(
            int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock),
            num_keys);
      }
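      // For example, the block with query_start = 0, kQueriesPerBlock = 64 and
      // causal_diagonal_offset = 0 can only see keys [0, 64), so num_keys is
      // clamped to 64 for that block.
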
      num_queries -= query_start;
      num_batches = 0; // no longer used after

      // If num_queries == 1, and there is only one key head, we're wasting
      // 15/16th of tensor core compute. In that case:
      // - we only launch kernels for head_id % kQueriesPerBlock == 0
      // - we iterate over heads instead of queries (strideM = strideH)
      if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) {
        if (head_id % kQueriesPerBlock != 0)
          return false;
        q_strideM = q_strideH;
        num_queries = num_heads;
        num_heads = 1; // unused but here for intent
        // remove causal since n_query = 1
        // otherwise, the offset would change with the head!
        custom_mask_type = NoCustomMask;
        o_strideM = head_dim_value;
      }

      // Make sure the compiler knows these variables are the same on all
      // the threads of the warp.
      // Only worth doing if they could have been modified above.
      query_ptr = warp_uniform(query_ptr);
      key_ptr = warp_uniform(key_ptr);
      value_ptr = warp_uniform(value_ptr);
      if (kSupportsBias) {
        attn_bias_ptr = warp_uniform(attn_bias_ptr);
      }
      output_ptr = warp_uniform(output_ptr);
      output_accum_ptr = warp_uniform(output_accum_ptr);
      logsumexp_ptr = warp_uniform(logsumexp_ptr);
      num_queries = warp_uniform(num_queries);
      num_keys = warp_uniform(num_keys);
      num_heads = warp_uniform(num_heads);
      o_strideM = warp_uniform(o_strideM);
      custom_mask_type = warp_uniform(custom_mask_type);
      return true;
    }

    __host__ dim3 getBlocksGrid() const {
      return dim3(
          ceil_div(num_queries, (int32_t)kQueriesPerBlock),
          num_heads,
          num_batches);
    }

    __host__ dim3 getThreadsGrid() const {
      return dim3(kWarpSize, kNumWarpsPerBlock, 1);
    }
  };

  struct MM0 {
    /*
      In this first matmul, we compute a block of `Q @ K.T`.
      While the calculation result is still hot in registers, we update
      `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
      into a shared-memory ("AccumulatorSharedStorage") that is used later as
      operand A for the second matmul (see MM1)
    */
    using GemmType = DefaultGemmType<ArchTag, scalar_t>;

    using OpClass = typename GemmType::OpClass;
    using DefaultConfig =
        typename cutlass::gemm::device::DefaultGemmConfiguration<
            OpClass,
            ArchTag,
            scalar_t,
            scalar_t,
            scalar_t, // ElementC
            accum_t // ElementAccumulator
            >;
    static constexpr int kAlignmentA =
        kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
    static constexpr int kAlignmentB =
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
    using ThreadblockShape = cutlass::gemm::
        GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
    using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
        scalar_t, // ElementA,
        cutlass::layout::RowMajor, // LayoutA,
        kAlignmentA,
        scalar_t, // ElementB,
        cutlass::layout::ColumnMajor, // LayoutB,
        kAlignmentB,
        accum_t,
        cutlass::layout::RowMajor, // LayoutC,
        OpClass,
        ArchTag, // ArchTag
        ThreadblockShape, // ThreadblockShape
        WarpShape, // WarpShape
        typename GemmType::InstructionShape, // InstructionShape
        ArchTag::kMinComputeCapability >= 80 && kIsHalf
            ? 4
            : DefaultConfig::kStages,
        typename GemmType::Operator // Operator
        >::DefaultMma;
    using MmaCore = typename DefaultMma::MmaCore;
    using IteratorA = typename DefaultMma::IteratorA;
    using IteratorB = typename DefaultMma::IteratorB;
    using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
    using Mma = typename cutlass::platform::conditional<
        kSingleValueIteration,
        typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
        DefaultThreadblockMma>::type;
    using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
        typename Mma::Operator::IteratorC,
        accum_t,
        kWarpSize>::Iterator;
    static_assert(
        MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
                MmaCore::WarpCount::kK ==
            kNumWarpsPerBlock,
        "");

    // used for efficient load of bias tile Bij from global to shared memory
    using BiasLoader = TileSmemLoader<
        scalar_t,
        cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
        MmaCore::kThreads,
        // input restriction: kv_len has to be a multiple of this value
        128 / cutlass::sizeof_bits<scalar_t>::value>;

    // Epilogue to store to shared-memory in a format that we can use later for
    // the second matmul
    using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
        typename Mma::Operator::IteratorC,
        typename Mma::Operator,
        scalar_t,
        WarpShape,
        ThreadblockShape>;
    using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
  };

  struct MM1 {
    /**
      Second matmul: perform `attn @ V` where `attn` is the attention (not
      normalized) and stored in shared memory
    */
    using GemmType = DefaultGemmType<ArchTag, scalar_t>;

    using OpClass = typename GemmType::OpClass;
    using DefaultConfig =
        typename cutlass::gemm::device::DefaultGemmConfiguration<
            OpClass,
            ArchTag,
            scalar_t,
            scalar_t,
            output_accum_t, // ElementC
            accum_t // ElementAccumulator
            >;
    static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
    static constexpr int kAlignmentB =
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
    using ThreadblockShape = cutlass::gemm::
        GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
    using InstructionShape = typename GemmType::InstructionShape;

    using LayoutB = cutlass::layout::RowMajor;
    using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
        scalar_t, // ElementA,
        cutlass::layout::RowMajor, // LayoutA,
        kAlignmentA,
        scalar_t, // ElementB,
        LayoutB, // LayoutB,
        kAlignmentB,
        output_accum_t,
        cutlass::layout::RowMajor, // LayoutC,
        accum_t,
        OpClass,
        ArchTag,
        ThreadblockShape,
        WarpShape,
        typename GemmType::InstructionShape,
        typename DefaultConfig::EpilogueOutputOp,
        void, // ThreadblockSwizzle - not used
        ArchTag::kMinComputeCapability >= 80 && kIsHalf
            ? 4
            : DefaultConfig::kStages,
        false, // SplitKSerial
        typename GemmType::Operator>;

    using WarpIteratorA = typename cutlass::gemm::threadblock::
        DefaultWarpIteratorAFromSharedMemory<
            typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
            typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
            typename DefaultGemm::Mma::Policy::Operator::IteratorA,
            typename DefaultGemm::Mma::Policy>::WarpIterator;
    using DefaultMmaFromSmem =
        typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
            typename DefaultGemm::Mma,
            MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
            WarpIteratorA,
            false>; // kScaleOperandA
    using Mma = typename DefaultMmaFromSmem::Mma;
    using IteratorB = typename Mma::IteratorB;
    using WarpCount = typename Mma::WarpCount;
    static_assert(
        WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
        "");

    using DefaultEpilogue = typename DefaultGemm::Epilogue;
    using OutputTileIterator =
        typename cutlass::epilogue::threadblock::PredicatedTileIterator<
            typename DefaultEpilogue::OutputTileIterator::ThreadMap,
            output_t>;
    using OutputTileIteratorAccum =
        typename cutlass::epilogue::threadblock::PredicatedTileIterator<
            typename DefaultEpilogue::OutputTileIterator::ThreadMap,
            output_accum_t>;
  };

  static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
  static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
  static constexpr int64_t kAlignmentV = 1;

  // Shared storage - depends on kernel params
  struct ScalingCoefs {
    cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
    cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
    cutlass::Array<accum_t, kQueriesPerBlock> mi;
    cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
    cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
        addition_storage;
  };

  struct SharedStorageEpilogueAtEnd : ScalingCoefs {
    struct SharedStorageAfterMM0 {
      // Everything here might be overwritten during MM0
      union {
        typename MM0::BiasLoader::SmemTile bias;
        typename MM0::AccumulatorSharedStorage si;
      };
      typename MM1::Mma::SharedStorage mm1;
    };

    union {
      typename MM0::Mma::SharedStorage mm0;
      SharedStorageAfterMM0 after_mm0;
      typename MM1::DefaultEpilogue::SharedStorage epilogue;
    };

    CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
    epilogue_shared_storage() {
      return epilogue;
    }
  };

  struct SharedStorageEpilogueInLoop : ScalingCoefs {
    struct SharedStorageAfterMM0 {
      // Everything here might be overwritten during MM0
      union {
        typename MM0::BiasLoader::SmemTile bias;
        typename MM0::AccumulatorSharedStorage si;
      };
      typename MM1::Mma::SharedStorage mm1;
      typename MM1::DefaultEpilogue::SharedStorage epilogue;
    };

    union {
      typename MM0::Mma::SharedStorage mm0;
      SharedStorageAfterMM0 after_mm0;
    };

    CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
    epilogue_shared_storage() {
      return after_mm0.epilogue;
    }
  };

  using SharedStorage = typename cutlass::platform::conditional<
      kSingleValueIteration || kKeepOutputInRF,
      SharedStorageEpilogueAtEnd,
      SharedStorageEpilogueInLoop>::type;

  static bool __host__ check_supported(Params const& p) {
    CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
    CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
    CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
    if (kSupportsBias) {
      CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
      XFORMERS_CHECK(
          p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
          "attn_bias is not correctly aligned (strideB)");
      XFORMERS_CHECK(
          p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
          "attn_bias is not correctly aligned (strideH)");
      XFORMERS_CHECK(
          p.bias_strideM % kAlignmentQ == 0,
          "attn_bias is not correctly aligned");
    }
    XFORMERS_CHECK(
        p.q_strideM % kAlignmentQ == 0,
        "query is not correctly aligned (strideM)");
    XFORMERS_CHECK(
        p.k_strideM % kAlignmentK == 0,
        "key is not correctly aligned (strideM)");
    XFORMERS_CHECK(
        p.v_strideM % kAlignmentV == 0,
        "value is not correctly aligned (strideM)");
    XFORMERS_CHECK(
        p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0,
        "query is not correctly aligned (strideH)");
    XFORMERS_CHECK(
        p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0,
        "key is not correctly aligned (strideH)");
    XFORMERS_CHECK(
        p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
        "value is not correctly aligned (strideH)");
    XFORMERS_CHECK(
        p.custom_mask_type < NumCustomMaskTypes,
        "invalid value for `custom_mask_type`");
    return true;
  }

  static void CUTLASS_DEVICE attention_kernel(Params& p) {
    // In this block, we will only ever:
    // - read query[query_start:query_end, :]
    // - write to output[query_start:query_end, :]

    extern __shared__ char smem_buffer[];
    SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
    auto& m_prime = shared_storage.m_prime;
    auto& s_prime = shared_storage.s_prime;
    auto& mi = shared_storage.mi;
    auto& out_rescale = shared_storage.out_rescale;
    const uint32_t query_start = blockIdx.x * kQueriesPerBlock;

    static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
    if (thread_id() < kQueriesPerBlock) {
      s_prime[thread_id()] = accum_t(0);
      out_rescale[thread_id()] = accum_t(1.0);
      m_prime[thread_id()] =
          -cutlass::platform::numeric_limits<accum_t>::infinity();
      mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
    }
    typename MM1::Mma::FragmentC accum_o;
    accum_o.clear();

    auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
      using OutputTileIterator = typename MM1::OutputTileIterator;
      return OutputTileIterator(
          typename OutputTileIterator::Params{(int32_t)p.o_strideM},
          p.output_ptr,
          typename OutputTileIterator::TensorCoord{
              p.num_queries, p.head_dim_value},
          thread_id(),
          {0, col});
    };

    auto createOutputAccumIter = [&](int col) ->
        typename MM1::OutputTileIteratorAccum {
      using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
      return OutputTileIteratorAccum(
          typename OutputTileIteratorAccum::Params{
              (int32_t)(p.head_dim_value * p.num_heads)},
          p.output_accum_ptr,
          typename OutputTileIteratorAccum::TensorCoord{
              p.num_queries, p.head_dim_value},
          thread_id(),
          {0, col});
    };

#ifdef HAS_PYTORCH
    curandStatePhilox4_32_10_t curand_state_init;
    if (kSupportsDropout && p.use_dropout) {
      const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);

      // each element of the attention matrix P with shape
      // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
      // offset in RNG sequence. we initialize the RNG state with offset that
      // starts at the beginning of a (n_queries, n_keys) matrix for this
      // block's batch_id and head_id
      // initializing rng state is very expensive, so we run once per kernel,
      // rather than once per iteration. each iteration takes a copy of the
      // initialized RNG state and offsets it as needed.
      curand_init(
          std::get<0>(seeds),
          0,
          std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
          &curand_state_init);
    }
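    // Each element (q, k) of this block's attention matrix therefore maps to
    // Philox offset dropout_batch_head_rng_offset + q * num_keys_absolute + k,
    // which is how the dropout loop below (see `skipahead`) lands on the same
    // random number that the backward pass regenerates for that element.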
#endif

    // Iterate through keys
    for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
         iter_key_start += kKeysPerBlock) {
      int32_t problem_size_0_m =
          cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
      int32_t problem_size_0_n = cutlass::fast_min(
          int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
      int32_t const& problem_size_0_k = p.head_dim;
      int32_t const& problem_size_1_n = p.head_dim_value;
      int32_t const& problem_size_1_k = problem_size_0_n;

      auto prologueV = [&](int blockN) {
        typename MM1::Mma::IteratorB iterator_V(
            typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)},
            p.value_ptr + iter_key_start * p.v_strideM,
            {problem_size_1_k, problem_size_1_n},
            thread_id(),
            cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
        MM1::Mma::prologue(
            shared_storage.after_mm0.mm1,
            iterator_V,
            thread_id(),
            problem_size_1_k);
      };

      __syncthreads(); // Need to have shared memory initialized, and `m_prime`
                       // updated from end of prev iter
      //
      // MATMUL: Q.K_t
      //
      // Computes the block-matrix product of:
      // (a) query[query_start:query_end, :]
      // with
      // (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
      // and stores that into `shared_storage.si`
      //

      // Compute threadblock location
      cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};

      cutlass::MatrixCoord tb_offset_A{
          tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};

      cutlass::MatrixCoord tb_offset_B{
          tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};

      // Construct iterators to A and B operands
      typename MM0::IteratorA iterator_A(
          typename MM0::IteratorA::Params(
              typename MM0::MmaCore::LayoutA(p.q_strideM)),
          p.query_ptr,
          {problem_size_0_m, problem_size_0_k},
          thread_id(),
          tb_offset_A);

      typename MM0::IteratorB iterator_B(
          typename MM0::IteratorB::Params(
              typename MM0::MmaCore::LayoutB(p.k_strideM)),
          p.key_ptr + iter_key_start * p.k_strideM,
          {problem_size_0_k, problem_size_0_n},
          thread_id(),
          tb_offset_B);

      auto my_warp_id = warp_uniform(warp_id());
      auto my_lane_id = lane_id();

      // Construct thread-scoped matrix multiply
      typename MM0::Mma mma(
          shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);

      typename MM0::Mma::FragmentC accum;

      accum.clear();

      auto gemm_k_iterations =
          (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;

      // Compute threadblock-scoped matrix multiply-add
      mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
      __syncthreads();

      if (kPreloadV) {
        prologueV(0);
      } else {
        MM1::Mma::drain_cp_asyncs();
      }

      typename MM0::Mma::Operator::IteratorC::TensorCoord
          iteratorC_tile_offset = {
              (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
                  (my_warp_id % MM0::Mma::WarpCount::kM),
              (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
                  (my_warp_id / MM0::Mma::WarpCount::kM)};

      // multiply by scaling factor
      if (kSupportsBias) {
        accum =
            cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
      }

      // apply attention bias if applicable
      if (kSupportsBias && p.attn_bias_ptr != nullptr) {
        // load bias tile Bij into shared memory
        typename MM0::BiasLoader::GmemTileIterator bias_iter(
            {cutlass::layout::RowMajor(p.bias_strideM)},
            // attn_bias_pointer points to matrix of size (n_queries, n_keys)
            // for the relevant batch_id and head_id
            p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start,
            {problem_size_0_m, problem_size_0_n},
            thread_id());
        cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
            shared_storage.after_mm0.bias.data(),
            cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
        typename MM0::BiasLoader::SmemTileIterator smem_tile_iter(
            bias_tensor_ref, thread_id());
        MM0::BiasLoader::load(bias_iter, smem_tile_iter);

        // Pij += Bij, Pij is in register fragment and Bij is in shared memory
        auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
            my_lane_id, my_warp_id, iteratorC_tile_offset);
        MM0::AccumLambdaIterator::iterateRows(
            lane_offset,
            [&](int accum_m) {},
            [&](int accum_m, int accum_n, int idx) {
              if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
                accum[idx] += bias_tensor_ref.at({accum_m, accum_n});
              }
            },
            [&](int accum_m) {});
      }

      // Mask out last if causal
      // This is only needed if the upper-right corner of the current
      // query/key block intersects the mask. The coordinates of the
      // upper-right corner of the current block are
      // y = query_start, x = min(iter_key_start + kKeysPerBlock, num_keys).
      // The first masked element of that row is x = y + offset, i.e.
      // query_start + offset. There is an intersection (and we need to mask)
      // if min(iter_key_start + kKeysPerBlock, num_keys) >=
      // query_start + offset.
      if (p.custom_mask_type &&
          cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >=
              (query_start + p.causal_diagonal_offset)) {
        auto query_start = blockIdx.x * kQueriesPerBlock;
        auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
            my_lane_id, my_warp_id, iteratorC_tile_offset);
        int32_t last_col;
        MM0::AccumLambdaIterator::iterateRows(
            lane_offset,
            [&](int accum_m) {
              // last absolute col is (last absolute query + offset)
              // last local col is (last absolute query + offset -
              // iter_key_start)
              last_col = query_start + accum_m + p.causal_diagonal_offset -
                  iter_key_start;
            },
            [&](int accum_m, int accum_n, int idx) {
              if (accum_n > last_col) {
                accum[idx] =
                    -cutlass::platform::numeric_limits<accum_t>::infinity();
              }
            },
            [&](int accum_m) {});
      }
      // Update `mi` from accum stored in registers
      // Also does accum[i] <- exp(accum[i] - mi)
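      // Note on the `scaling` argument below: when kSupportsBias is enabled,
      // `accum` was already multiplied by p.scale before the bias was added,
      // so we pass 1.0f here; otherwise p.scale is folded into the log2(e)
      // rescaling done inside iterative_softmax.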
      iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
          accum_o,
          accum,
          mi,
          m_prime,
          s_prime,
          out_rescale,
          shared_storage.addition_storage,
          my_lane_id,
          thread_id(),
          my_warp_id,
          p.num_keys - iter_key_start,
          iter_key_start == 0,
          iteratorC_tile_offset,
          kSupportsBias ? 1.0f : p.scale);

      // Output results to shared-memory
      int warp_idx_mn_0 = my_warp_id %
          (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
      auto output_tile_coords = cutlass::MatrixCoord{
          warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
          warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};

      MM0::B2bGemm::accumToSmem(
          shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);

      __syncthreads();

#ifdef HAS_PYTORCH
      // apply dropout (if applicable) after we've written Pij to smem.
      // dropout is applied by multiplying each element of Pij by:
      // - 0 with probability dropout_p
      // - 1 / (1 - dropout_p) with probability 1 - dropout_p
      //
      // for backward purposes we want to be able to map each element of the
      // attention matrix to the same random uniform number as the one we used
      // in forward, without needing to use the same iteration order or having
      // to store the dropout matrix. it's possible to do this in registers but
      // it ends up being very slow because each thread having noncontiguous
      // strips of the Pij tile means we have to skip around a lot, and also
      // have to generate a single random number at a time
      if (kSupportsDropout && p.use_dropout) {
        auto si = shared_storage.after_mm0.si.accum_ref();
        // each thread handles a contiguous sequence of elements from Sij, all
        // coming from the same row. the reason they have to come from the same
        // row is that sampling random numbers from a contiguous random
        // number sequence is much more efficient than jumping around, and the
        // linear offset of each element of S (the global matrix) maps to an
        // offset in a random number sequence. for S, the end of a row and the
        // beginning of the next have adjacent offsets, but for Sij, this is not
        // necessarily the case.
        const int num_threads = blockDim.x * blockDim.y * blockDim.z;
        const int threads_per_row =
            cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n);
        const int elts_per_thread = cutlass::round_nearest(
            cutlass::ceil_div(problem_size_0_n, threads_per_row), 4);

        const int thread_i = thread_id() / threads_per_row;
        const int thread_start_j =
            (thread_id() % threads_per_row) * elts_per_thread;

        if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) {
          curandStatePhilox4_32_10_t curand_state = curand_state_init;
          skipahead(
              static_cast<unsigned long long>(
                  (query_start + thread_i) * p.num_keys_absolute +
                  (iter_key_start + thread_start_j)),
              &curand_state);
          const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);

          // apply dropout scaling to elements this thread is responsible for,
          // in chunks of 4
          for (int sij_start_col_idx = thread_start_j; sij_start_col_idx <
               cutlass::fast_min(
                   thread_start_j + elts_per_thread, problem_size_0_n);
               sij_start_col_idx += 4) {
            const float4 rand_uniform_quad = curand_uniform4(&curand_state);

            CUTLASS_PRAGMA_UNROLL
            for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
              si.at({thread_i, sij_start_col_idx + quad_idx}) *=
                  static_cast<scalar_t>(
                      dropout_scale *
                      ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob));
            }
          }
        }
        __syncthreads(); // p.use_dropout should have same value kernel-wide
      }
#endif

      //
      // MATMUL: Attn . V
      // Run the matmul `attn @ V` for a block of attn and V.
      // `attn` is read from shared memory (in `shared_storage_si`)
      // `V` is read from global memory (with iterator_B)
      //

      const int64_t nBlockN = kSingleValueIteration
          ? 1
          : ceil_div(
                (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
      for (int blockN = 0; blockN < nBlockN; ++blockN) {
        int gemm_k_iterations =
            (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add and store it in accum
        // (in registers)
        if (!kPreloadV) {
          __syncthreads(); // we share shmem between mma and epilogue
        }

        typename MM1::Mma::IteratorB iterator_V(
            typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)},
            p.value_ptr + iter_key_start * p.v_strideM,
            {problem_size_1_k, problem_size_1_n},
            thread_id(),
            cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
        typename MM1::Mma mma_pv(
            // operand A: Pij_dropped in shared memory
            shared_storage.after_mm0.si.accum_ref(),
            // operand B: shared memory staging area for Vj, which is loaded
            // from global memory
            shared_storage.after_mm0.mm1.operand_B_ref(),
            (int)thread_id(),
            (int)my_warp_id,
            (int)my_lane_id);
        mma_pv.set_prologue_done(kPreloadV);
        if (!kKeepOutputInRF) {
          accum_o.clear();
        }
        mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
        __syncthreads();

        if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
          prologueV(blockN + 1);
        }

        if (!kKeepOutputInRF) {
          MM1::Mma::drain_cp_asyncs();
          DISPATCH_BOOL(
              iter_key_start == 0, kIsFirst, ([&] {
                DISPATCH_BOOL(
                    (iter_key_start + kKeysPerBlock) >= p.num_keys,
                    kIsLast,
                    ([&] {
                      using DefaultEpilogue = typename MM1::DefaultEpilogue;
                      using DefaultOp =
                          typename MM1::DefaultConfig::EpilogueOutputOp;
                      using ElementCompute = typename DefaultOp::ElementCompute;
                      using EpilogueOutputOp = typename cutlass::epilogue::
                          thread::MemoryEfficientAttentionNormalize<
                              typename cutlass::platform::conditional<
                                  kIsLast,
                                  output_t,
                                  output_accum_t>::type,
                              output_accum_t,
                              DefaultOp::kCount,
                              typename DefaultOp::ElementAccumulator,
                              ElementCompute,
                              kIsFirst,
                              kIsLast,
                              cutlass::Array<ElementCompute, kQueriesPerBlock>>;
                      using Epilogue = typename cutlass::epilogue::threadblock::
                          EpiloguePipelined<
                              typename DefaultEpilogue::Shape,
                              typename MM1::Mma::Operator,
                              DefaultEpilogue::kPartitionsK,
                              typename cutlass::platform::conditional<
                                  kIsLast,
                                  typename MM1::OutputTileIterator,
                                  typename MM1::OutputTileIteratorAccum>::type,
                              typename DefaultEpilogue::
                                  AccumulatorFragmentIterator,
                              typename DefaultEpilogue::WarpTileIterator,
                              typename DefaultEpilogue::SharedLoadIterator,
                              EpilogueOutputOp,
                              typename DefaultEpilogue::Padding,
                              DefaultEpilogue::kFragmentsPerIteration,
                              true, // IterationsUnroll
                              typename MM1::OutputTileIteratorAccum // Read
                                                                    // iterator
                              >;

                      int col = blockN * MM1::Mma::Shape::kN;
                      auto source_iter = createOutputAccumIter(col);
                      auto dest_iter = call_conditional<
                          kIsLast,
                          decltype(createOutputIter),
                          decltype(createOutputAccumIter)>::
                          apply(createOutputIter, createOutputAccumIter, col);
                      EpilogueOutputOp rescale(s_prime, out_rescale);
                      Epilogue epilogue(
                          shared_storage.epilogue_shared_storage(),
                          thread_id(),
                          my_warp_id,
                          my_lane_id);
                      epilogue(rescale, dest_iter, accum_o, source_iter);
                    }));
              }));
          if (!kSingleValueIteration) {
            __syncthreads();
          }
        }
      }
      __syncthreads(); // we modify `m_prime` after
    }

    if (kKeepOutputInRF) {
      constexpr bool kIsFirst = true;
      constexpr bool kIsLast = true;
      using DefaultEpilogue = typename MM1::DefaultEpilogue;
      using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
      using ElementCompute = typename DefaultOp::ElementCompute;
      using EpilogueOutputOp =
          typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
              output_t, // output
              output_accum_t, // source
              DefaultOp::kCount,
              typename DefaultOp::ElementAccumulator, // accum
              output_accum_t, // compute
              kIsFirst,
              kIsLast,
              cutlass::Array<ElementCompute, kQueriesPerBlock>>;
      using Epilogue =
          typename cutlass::epilogue::threadblock::EpiloguePipelined<
              typename DefaultEpilogue::Shape,
              typename MM1::Mma::Operator,
              DefaultEpilogue::kPartitionsK,
              typename MM1::OutputTileIterator, // destination
              typename DefaultEpilogue::AccumulatorFragmentIterator,
              typename DefaultEpilogue::WarpTileIterator,
              typename DefaultEpilogue::SharedLoadIterator,
              EpilogueOutputOp,
              typename DefaultEpilogue::Padding,
              DefaultEpilogue::kFragmentsPerIteration,
              true, // IterationsUnroll
              typename MM1::OutputTileIteratorAccum // source tile
              >;
      auto dest_iter = createOutputIter(0);
      EpilogueOutputOp rescale(s_prime, out_rescale);
      Epilogue epilogue(
          shared_storage.epilogue_shared_storage(),
          thread_id(),
          warp_id(),
          lane_id());
      MM1::Mma::drain_cp_asyncs();
      epilogue(rescale, dest_iter, accum_o);
    }

    // 7. Calculate logsumexp
    // To make the backward easier, we pad logsumexp with `inf`
    // this avoids a few bound checks, and is not more expensive during fwd
    static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
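    // `mi` holds the per-row running max of the logits scaled by log2(e)
    // (see iterative_softmax), so mi / log2(e) converts back to natural log
    // and the stored value is logsumexp(row) = max + log(sum(exp(row - max))).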
    if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
      auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
      constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
      if (thread_id() < p.num_queries) {
        p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
            cutlass::fast_log(accum_t(s_prime[thread_id()]));
      } else if (thread_id() < lse_dim) {
        p.logsumexp_ptr[thread_id()] =
            cutlass::platform::numeric_limits<accum_t>::infinity();
      }
    }
  }

  template <typename WarpIteratorC>
  CUTLASS_DEVICE static void iterative_softmax(
      typename WarpIteratorC::Fragment& frag_o, // output so far
      typename WarpIteratorC::Fragment& frag,
      cutlass::Array<accum_t, kQueriesPerBlock>& mi,
      cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
      cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
      cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
      cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
          addition_storage,
      int8_t lane_id,
      int8_t thread_id,
      int8_t warp_id,
      int max_col,
      bool is_first,
      typename WarpIteratorC::TensorCoord const& tile_offset,
      float scaling) {
    /* Iterates on the accumulator and corresponding position on result matrix

    (1) Update `mi[r]` to the max value of the row `r`
    (2) In a second iteration do the following:
        (a) accum   <- exp(accum - mi)
        (b) m_prime <- exp(m_prime - mi)
        (c) s_prime <- s_prime * m_prime + sum(accum)

    All of this is done on registers, before we store all of this
    on shared memory for the next matmul with Value.
    */
    using Fragment = typename WarpIteratorC::Fragment;
    using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
        WarpIteratorC,
        accum_t,
        kWarpSize>::Iterator;
    // Convert to `accum_t` (rather than double)
    constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E

    static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
    static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;

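    // Work in base 2: exp(x - m) == exp2(x * log2(e) - m * log2(e)), so by
    // folding log2(e) (and the softmax scale) into `frag` once here, every
    // later exponential can use the cheaper exp2f instruction.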
    frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);

    auto lane_offset =
        LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);

    // First update `mi` to the max per-row
    {
      accum_t max;
      LambdaIterator::iterateRows(
          lane_offset,
          [&](int accum_m) {
            max = -cutlass::platform::numeric_limits<accum_t>::infinity();
          },
          [&](int accum_m, int accum_n, int idx) {
            if (accum_n < max_col) {
              max = cutlass::fast_max(max, frag[idx]);
            }
          },
          [&](int accum_m) {
            // Having 4x atomicMax seems faster than reduce within warp
            // first...
            atomicMaxFloat(&mi[accum_m], max);
          });
    }

    // Make sure we all share the updated values for `mi`
    __syncthreads();

    // Doing this `exp` is quite expensive. Let's
    // split it across the warps
    bool restore_mi_to_minus_inf = false;
    if (lane_id < kLinesPerWarp) {
      int id = warp_id * kLinesPerWarp + lane_id;
      auto m_prime_id = m_prime[id];
      auto mi_id = mi[id];
      bool changed = m_prime_id < mi_id; // `false` if both are -inf
      if (changed) {
        auto m_prime_exp = exp2f(m_prime_id - mi_id);
        out_rescale[id] = m_prime_exp;
        s_prime[id] *= m_prime_exp;
      } else {
        // Only when bias is enabled is it possible that all the first values
        // of attention are masked to `-inf`. In that case we want to avoid
        // `nan = exp2f(-inf - (-inf))`, so we temporarily set `mi` to 0
        if (kSupportsBias &&
            mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
          restore_mi_to_minus_inf = true;
          mi[id] = 0.0f;
        }
        out_rescale[id] = 1.0f;
      }
    }
    __syncthreads(); // Update output fragments
    if (kKeepOutputInRF && !is_first) {
      accum_t line_rescale;
      LambdaIterator::iterateRows(
          lane_offset,
          [&](int accum_m) { line_rescale = out_rescale[accum_m]; },
          [&](int accum_m, int accum_n, int idx) {
            frag_o[idx] = frag_o[idx] * line_rescale;
          },
          [&](int accum_m) {});
    }
    // Update accum_m, accum_n, ...
    {
      accum_t mi_row, total_row;
      LambdaIterator::iterateRows(
          lane_offset,
          [&](int accum_m) { mi_row = mi[accum_m]; },
          [&](int accum_m, int accum_n, int idx) {
            frag[idx] =
                (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
          },
          [&](int accum_m) {});
      LambdaIterator::iterateRows(
          lane_offset,
          [&](int accum_m) { total_row = 0.0; },
          [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
          [&](int accum_m) {
            if (LambdaIterator::reduceSameRow(
                    lane_id, total_row, [](accum_t a, accum_t b) {
                      return a + b;
                    })) {
              // NOTE: we could atomically add `total_row` to `s_prime`, but
              // it's faster (and deterministic) to avoid atomics here
              addition_storage
                  [accum_m + kQueriesPerBlock * tile_offset.column()] =
                      total_row;
            }
          });
    }
    __syncthreads();
    if (lane_id < kLinesPerWarp) {
      int id = warp_id * kLinesPerWarp + lane_id;
      accum_t total_row = s_prime[id];
      if (restore_mi_to_minus_inf) {
        // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
        mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
      } else {
        m_prime[id] = mi[id];
      }
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
        total_row += addition_storage[id + kQueriesPerBlock * i];
      }
      s_prime[id] = total_row;
    }
  }

  static CUTLASS_DEVICE int8_t lane_id() {
    return threadIdx.x;
  }
  static CUTLASS_DEVICE int8_t warp_id() {
    return threadIdx.y;
  }
  static CUTLASS_DEVICE int16_t thread_id() {
    return threadIdx.x + threadIdx.y * blockDim.x;
  }
};

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_batched_impl(typename AK::Params p) {
  if (!p.advance_to_block()) {
    return;
  }
  AK::attention_kernel(p);
}

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_batched(typename AK::Params params);

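// A minimal host-side launch sketch (illustrative only; `MyKernel`, `params`
// and the error handling are assumptions, not part of this header):
//
//   using MyKernel = AttentionKernel<
//       cutlass::half_t, cutlass::arch::Sm80,
//       /*isAligned=*/true, /*kQueriesPerBlock=*/64, /*kKeysPerBlock=*/64,
//       /*kMaxK=*/64>;
//   typename MyKernel::Params params;
//   // ... fill pointers, dims and strides, then:
//   int smem_bytes = sizeof(typename MyKernel::SharedStorage);
//   if (smem_bytes > 0xc000) {
//     cudaFuncSetAttribute(
//         attention_kernel_batched_impl<MyKernel>,
//         cudaFuncAttributeMaxDynamicSharedMemorySize,
//         smem_bytes);
//   }
//   if (!MyKernel::check_supported(params)) { /* handle the error */ }
//   attention_kernel_batched_impl<MyKernel>
//       <<<params.getBlocksGrid(), params.getThreadsGrid(), smem_bytes>>>(
//           params);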