215 lines
9.5 KiB
C++
215 lines
9.5 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
#pragma once
|
|
|
|
#include <cute/config.hpp>
|
|
|
|
#include <cute/atom/copy_atom.hpp>
|
|
|
|
#include <cute/algorithm/copy.hpp>
|
|
|
|
#include <cute/tensor.hpp>
|
|
#include <cute/tensor_predicate.hpp>
|
|
|
|
namespace cute
|
|
{
|
|
|
|
// cooperative_copy<NumThreads, MaxVecBits>(thr_idx, src, dst)
|
|
// Use NumThreads to copy src to dst with element vectorization up to MaxVecBits.
|
|
// @pre 0 <= @a tid < NumThreads
|
|
// @pre Tensors @a src and @a dst are aligned up to MaxVecBits.
|
|
//
|
|
template <uint32_t NumThreads, uint32_t MaxVecBits,
|
|
class SrcEngine, class SrcLayout,
|
|
class DstEngine, class DstLayout>
|
|
CUTE_HOST_DEVICE
|
|
void
|
|
cooperative_copy(uint32_t const& tid,
|
|
Tensor<SrcEngine, SrcLayout> const& src,
|
|
Tensor<DstEngine, DstLayout> & dst)
|
|
{
|
|
// Assumes the shapes are static, can generalize
|
|
CUTE_STATIC_ASSERT_V(size(src) == size(dst));
|
|
// Assumes the types are the same, can generalize
|
|
static_assert(sizeof_bits_v<typename SrcEngine::value_type> == sizeof_bits_v<typename DstEngine::value_type>);
|
|
static_assert(MaxVecBits == sizeof_bits_v<typename SrcEngine::value_type> ||
|
|
MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128,
|
|
"Expected MaxVecBits to be value size or 8 or 16 or 32 or 64 or 128 for alignment and performance.");
|
|
// Check that the tensors are likely shared across threads: either gmem or smem
|
|
static_assert((is_gmem<SrcEngine>::value || is_smem<SrcEngine>::value),
|
|
"cooperative_copy expects shared gmem or smem source tensor.");
|
|
static_assert((is_gmem<DstEngine>::value || is_smem<DstEngine>::value),
|
|
"cooperative_copy expects shared gmem or smem destination tensor.");
|
|
|
|
// Precondition on tid in DEBUG
|
|
assert(tid < NumThreads);
|
|
|
|
// Fallback - slow path, naive copy, vectorization disabled
|
|
if constexpr(size(SrcLayout{}) % NumThreads != 0) {
|
|
int index = static_cast<int>(tid);
|
|
CUTE_UNROLL
|
|
for(int i = 0; i < ceil_div(size(SrcLayout{}), NumThreads); i++) {
|
|
if(index < size(SrcLayout{})) {
|
|
dst[index] = src[index];
|
|
}
|
|
index += NumThreads;
|
|
}
|
|
} else {
|
|
// Fast path with vectorization
|
|
|
|
// Precondition on pointer alignment in DEBUG
|
|
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(src.data())));
|
|
assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(dst.data())));
|
|
constexpr int elem_bits = sizeof_bits_v<typename SrcEngine::value_type>;
|
|
|
|
//
|
|
// Determine val+thr vectorization based on src/dst size and number of threads
|
|
// NOTE: This heuristic promotes parallelization over vectorization
|
|
//
|
|
|
|
// The number of elements that can be vectorized in values
|
|
constexpr int common_elem = decltype(max_common_vector(src, dst))::value;
|
|
constexpr int common_bits = common_elem * elem_bits;
|
|
constexpr int total_elem = decltype(size(src))::value;
|
|
constexpr int total_bits = total_elem * elem_bits;
|
|
static_assert(total_bits % NumThreads == 0);
|
|
constexpr int total_bits_per_thr = total_bits / NumThreads;
|
|
// If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits
|
|
constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr);
|
|
|
|
// Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits
|
|
constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast<int>(MaxVecBits));
|
|
// Convert back to number of elements, safe_div
|
|
static_assert((vec_bits % elem_bits) == 0);
|
|
constexpr int vec_elem = vec_bits / elem_bits;
|
|
|
|
// Use only part of threads if there's not enough work for all threads
|
|
constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0)
|
|
? NumThreads
|
|
: (total_elem / vec_elem);
|
|
static_assert(vec_thrs <= NumThreads);
|
|
|
|
// The common layout of the two tensors that can be vectorized over threads
|
|
// vidx -> coord
|
|
auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()),
|
|
get_nonswizzle_portion(dst.layout()));
|
|
|
|
// Scale up the common_layout to cover the entire tensors
|
|
// vidx -> coord
|
|
auto full_perm = tile_to_shape(make_layout(common_layout), size(src));
|
|
|
|
// Create the Tiler
|
|
// ((vid,tid),iter)
|
|
auto layout_vt = logical_divide(full_perm, Layout<Shape<Int<vec_elem>, Int<vec_thrs>>>{});
|
|
|
|
// Apply and slice
|
|
Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_);
|
|
Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_);
|
|
|
|
// Should account for vec_bits < 8 and/or vec_elem <= 1
|
|
// And also account for subbyte types, which could cause race conditions
|
|
// Want to ENFORCE sufficient vectorization in those cases
|
|
static_assert((vec_bits >= 8), "No support for subbyte copying");
|
|
using VecType = uint_bit_t<vec_bits>;
|
|
|
|
#if 0
|
|
if (thread0()) {
|
|
print(" "); print("cooperative_copy -- vec\n");
|
|
print(" "); print("NumThreads: "); print(NumThreads); print("\n");
|
|
print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n");
|
|
print(" "); print("src: "); print(src); print("\n");
|
|
print(" "); print("dst: "); print(dst); print("\n");
|
|
print(" "); print("common_layout: "); print(common_layout); print("\n");
|
|
print(" "); print("full_perm: "); print(full_perm); print("\n");
|
|
print(" "); print("Used vector: "); print(vec_elem); print("\n");
|
|
print(" "); print("Used threads: "); print(vec_thrs); print("\n");
|
|
print(" "); print("layout_vt: "); print(layout_vt); print("\n");
|
|
print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n");
|
|
print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n");
|
|
print(" "); print("src_v: "); print(src_v); print("\n");
|
|
print(" "); print("dst_v: "); print(dst_v); print("\n");
|
|
print(" "); print("recast<VecType const>(src_v): "); print(recast<VecType const>(src_v)); print("\n");
|
|
print(" "); print("recast<VecType const>(dst_v): "); print(recast<VecType const>(dst_v)); print("\n");
|
|
}
|
|
#ifdef __CUDA_ARCH__
|
|
__syncthreads();
|
|
#endif
|
|
#endif
|
|
|
|
// If we're using all threads (static) or the tid is in in-range (dynamic)
|
|
if (vec_thrs >= NumThreads or tid < vec_thrs) {
|
|
return copy_if(TrivialPredTensor{}, recast<VecType const>(src_v), recast<VecType>(dst_v));
|
|
}
|
|
}
|
|
}
|
|
|
|
template <uint32_t NumThreads,
|
|
class SrcEngine, class SrcLayout,
|
|
class DstEngine, class DstLayout>
|
|
CUTE_HOST_DEVICE
|
|
void
|
|
cooperative_copy(uint32_t const& tid,
|
|
Tensor<SrcEngine, SrcLayout> const& src,
|
|
Tensor<DstEngine, DstLayout> & dst)
|
|
{
|
|
constexpr uint32_t MaxVecBits = sizeof_bits_v<typename SrcEngine::value_type>;
|
|
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
|
|
}
|
|
|
|
// Accept mutable temporaries
|
|
template <uint32_t NumThreads,
|
|
class SrcEngine, class SrcLayout,
|
|
class DstEngine, class DstLayout>
|
|
CUTE_HOST_DEVICE
|
|
void
|
|
cooperative_copy(uint32_t const& tid,
|
|
Tensor<SrcEngine, SrcLayout> const& src,
|
|
Tensor<DstEngine, DstLayout> && dst)
|
|
{
|
|
return cooperative_copy<NumThreads>(tid, src, dst);
|
|
}
|
|
|
|
// Accept mutable temporaries
|
|
template <uint32_t NumThreads,
|
|
uint32_t MaxVecBits,
|
|
class SrcEngine, class SrcLayout,
|
|
class DstEngine, class DstLayout>
|
|
CUTE_HOST_DEVICE
|
|
void
|
|
cooperative_copy(uint32_t const& tid,
|
|
Tensor<SrcEngine, SrcLayout> const& src,
|
|
Tensor<DstEngine, DstLayout> && dst)
|
|
{
|
|
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
|
|
}
|
|
|
|
} // end namespace cute
|