// File: cutlass/test/unit/cute/cooperative_gemm_common.hpp (415 lines, 18 KiB, C++)
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass_unit_test.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include <iostream>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
using namespace cute;
// Single-thread-block test kernel for cute::cooperative_gemm.
//
// Stages the A, B, and C operand tiles from global memory into dynamic shared
// memory, runs cooperative_gemm on the shared-memory tiles, then copies the
// C tile back out to a separate global buffer (c_out) so the input C is left
// intact for the host-side reference check.
//
// Launch contract (enforced by the host wrapper below):
//   - exactly one thread block of ThreadBlockSize threads,
//   - dynamic shared memory large enough for the three smem tiles, with the
//     B and C tiles aligned to CopyMaxVecBits/8 bytes (see offsets below).
//
// The *Transform functors are applied by cooperative_gemm to the operand
// values on load (A/B/C) and to the result on store.
template<class ALayout,
class BLayout,
class CLayout,
class SMemALayout,
class SMemBLayout,
class SMemCLayout,
class SmemCopyOpA,
class SmemCopyOpB,
class SmemCopyOpC,
uint32_t ThreadBlockSize,
class TiledMma,
uint32_t CopyMaxVecBits,
class TA,
class TB,
class TC,
class Alpha,
class Beta,
class ALoadTransform,
class BLoadTransform,
class CLoadTransform,
class CStoreTransform>
__launch_bounds__(ThreadBlockSize) __global__ void
cooperative_gemm_kernel(TA const* a,
TB const* b,
TC* c,
TC* c_out,
Alpha const alpha,
Beta const beta,
ALoadTransform a_load_transform,
BLoadTransform b_load_transform,
CLoadTransform c_load_transform,
CStoreTransform c_store_transform)
{
using namespace cute;
// Global-memory views over the raw operand pointers.
Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), ALayout{});
Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), BLayout{});
Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), CLayout{});
Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), CLayout{});
constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8;
// Carve the dynamic shared-memory buffer into three tiles; B and C start at
// offsets rounded up to the copy vectorization width so vectorized
// cooperative_copy accesses stay aligned.
extern __shared__ float4 smem_buf[];
auto* smem_ptr = reinterpret_cast<unsigned char*>(smem_buf);
auto* smem_ptr_a = smem_ptr;
auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(SMemALayout {})), copy_max_vec_bytes);
auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(SMemBLayout {})), copy_max_vec_bytes);
Tensor s_a_tensor = make_tensor(make_smem_ptr<TA>(smem_ptr_a), SMemALayout{});
Tensor s_b_tensor = make_tensor(make_smem_ptr<TB>(smem_ptr_b), SMemBLayout{});
Tensor s_c_tensor = make_tensor(make_smem_ptr<TC>(smem_ptr_c), SMemCLayout{});
// Stage all three operands gmem -> smem across the whole block.
cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_a_tensor, s_a_tensor);
cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_b_tensor, s_b_tensor);
cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_c_tensor, s_c_tensor);
// Drain any cp.async copies the cooperative_copy may have issued, then
// barrier so every thread sees the fully staged tiles before the MMA.
cp_async_fence();
cp_async_wait<0>();
__syncthreads();
TiledMma tiled_mma;
// s_c_tensor is updated in place: C = store(alpha * A*B + beta * load(C)).
cooperative_gemm<SmemCopyOpA, SmemCopyOpB, SmemCopyOpC>(
threadIdx.x, tiled_mma,
alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor,
a_load_transform, b_load_transform, c_load_transform, c_store_transform
);
// Barrier before reading the result tile back out smem -> gmem.
__syncthreads();
cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, s_c_tensor, g_c_out_tensor);
}
// Host-side driver for cooperative_gemm_kernel.
//
// Fills A/B/C with deterministic data, launches the kernel on a single
// thread block, computes an fp64 host reference of
//   D = c_store_transform(alpha * A*B + beta * c_load_transform(C))
// (with a_load_transform / b_load_transform applied to the operands), and
// asserts elementwise relative equality against the device result.
//
// All GEMM configuration is carried by the template parameters; the four
// functor arguments default to cute::identity.
template<class ALayout,     // logical shape (M, K)
         class BLayout,     // logical shape (N, K)
         class CLayout,     // logical shape (M, N)
         class SMemALayout, // logical shape (M, K)
         class SMemBLayout, // logical shape (N, K)
         class SMemCLayout, // logical shape (M, N)
         class SmemCopyOpA,
         class SmemCopyOpB,
         class SmemCopyOpC,
         uint32_t ThreadBlockSize,
         class TiledMma,
         uint32_t CopyMaxVecBits,
         class TA,
         class TB,
         class TC,
         class ALoadTransform = cute::identity,
         class BLoadTransform = cute::identity,
         class CLoadTransform = cute::identity,
         class CStoreTransform = cute::identity>
void test_cooperative_gemm(ALoadTransform const& a_load_transform = {},
                           BLoadTransform const& b_load_transform = {},
                           CLoadTransform const& c_load_transform = {},
                           CStoreTransform const& c_store_transform = {})
{
  using gmem_a_layout_t = ALayout;
  using gmem_b_layout_t = BLayout;
  using gmem_c_layout_t = CLayout;
  using smem_a_layout_t = SMemALayout;
  using smem_b_layout_t = SMemBLayout;
  using smem_c_layout_t = SMemCLayout;

  // Shape compatibility: gmem and smem tiles must describe the same GEMM.
  static_assert(size<0>(gmem_a_layout_t{}) == size<0>(gmem_c_layout_t{})); // AM == CM
  static_assert(size<0>(gmem_b_layout_t{}) == size<1>(gmem_c_layout_t{})); // BN == CN
  static_assert(size<1>(gmem_a_layout_t{}) == size<1>(gmem_b_layout_t{})); // AK == BK
  static_assert(size<0>(smem_a_layout_t{}) == size<0>(smem_c_layout_t{})); // AM == CM
  static_assert(size<0>(smem_b_layout_t{}) == size<1>(smem_c_layout_t{})); // BN == CN
  static_assert(size<1>(smem_a_layout_t{}) == size<1>(smem_b_layout_t{})); // AK == BK
  static_assert(cute::size(gmem_a_layout_t {}) == cute::size(smem_a_layout_t {}));
  static_assert(cute::size(gmem_b_layout_t {}) == cute::size(smem_b_layout_t {}));
  static_assert(cute::size(gmem_c_layout_t {}) == cute::size(smem_c_layout_t {}));

#if 0
  print("  "); print("gmem:    "); print(gmem_layout_t{}); print("\n");
  print("  "); print("smem:    "); print(smem_layout_t{}); print("\n");
  print("  "); print("threads: "); print(ThreadBlockSize); print("\n");
#endif

  const auto alpha = static_cast<TC>(1.1);
  const auto beta  = static_cast<TC>(1.2);

  // Host buffers are sized by cosize: layouts may be non-compact.
  thrust::host_vector<TA> h_a(cosize(gmem_a_layout_t{}));
  thrust::host_vector<TB> h_b(cosize(gmem_b_layout_t{}));
  thrust::host_vector<TC> h_c(cosize(gmem_c_layout_t{}));
  thrust::host_vector<TC> h_c_out(cosize(gmem_c_layout_t{}));

  auto h_a_tensor = make_tensor(h_a.data(), gmem_a_layout_t{});
  auto h_b_tensor = make_tensor(h_b.data(), gmem_b_layout_t{});
  auto h_c_tensor = make_tensor(h_c.data(), gmem_c_layout_t{});

  // Fill operands with deterministic, type-appropriate values scaled by the
  // tensor's own element count so they stay in a low-precision-friendly range.
  size_t max_size = std::max<size_t>({static_cast<size_t>(size(gmem_a_layout_t {})),
                                      static_cast<size_t>(size(gmem_b_layout_t {})),
                                      static_cast<size_t>(size(gmem_c_layout_t {}))});
  for (size_t i = 0; i < max_size; ++i) {
    double di = static_cast<double>(i);
    if(i < size(gmem_a_layout_t{})) {
      h_a_tensor(i) = static_cast<TA>(di / size(gmem_a_layout_t{}));
    }
    if(i < size(gmem_b_layout_t{})) {
      // Fix: cast through TB (not TA) so mixed-precision TA/TB tests get
      // correctly-typed B data, and scale by B's own size.
      h_b_tensor(i) = static_cast<TB>(di / size(gmem_b_layout_t{}));
    }
    if(i < size(gmem_c_layout_t{})) {
      h_c_tensor(i) = static_cast<TC>((di*di) / size(gmem_a_layout_t{}));
    }
  }

  thrust::device_vector<TA> d_a(h_a);
  thrust::device_vector<TB> d_b(h_b);
  thrust::device_vector<TC> d_c(h_c);
  thrust::device_vector<TC> d_c_out(h_c_out.size(), TC(float(-1)));

  // The kernel places the B and C smem tiles at offsets rounded up to the copy
  // vectorization width and sizes them from the *smem* layout cosizes, so the
  // dynamic shared-memory allocation must include that padding and use the
  // same cosizes (which may differ from the gmem cosizes).
  constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8;
  const size_t shared_memory_size =
      round_up(sizeof(TA) * cosize(smem_a_layout_t{}), copy_max_vec_bytes) +
      round_up(sizeof(TB) * cosize(smem_b_layout_t{}), copy_max_vec_bytes) +
      (sizeof(TC) * cosize(smem_c_layout_t{}));

  auto kernel = cooperative_gemm_kernel<
      gmem_a_layout_t, gmem_b_layout_t, gmem_c_layout_t,
      smem_a_layout_t, smem_b_layout_t, smem_c_layout_t,
      SmemCopyOpA, SmemCopyOpB, SmemCopyOpC,
      ThreadBlockSize, TiledMma, CopyMaxVecBits,
      TA, TB, TC, decltype(alpha), decltype(beta),
      ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform
  >;
  // Opt in to >48 KB dynamic shared memory where the configuration needs it.
  ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(shared_memory_size)), cudaSuccess);

  kernel<<<1, ThreadBlockSize, shared_memory_size>>>(
    thrust::raw_pointer_cast(d_a.data()),
    thrust::raw_pointer_cast(d_b.data()),
    thrust::raw_pointer_cast(d_c.data()),
    thrust::raw_pointer_cast(d_c_out.data()),
    alpha,
    beta,
    a_load_transform,
    b_load_transform,
    c_load_transform,
    c_store_transform
  );

  // Catch launch-configuration errors immediately; async execution errors
  // surface at the synchronize (whose own return code is the error to report).
  cudaError_t launch_result = cudaGetLastError();
  if (launch_result != cudaSuccess) {
    FAIL() << "Error at kernel launch: " << cudaGetErrorString(launch_result) << "\n";
  }
  cudaError_t result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    FAIL() << "Error at kernel sync: " << cudaGetErrorString(result) << "\n";
  }

  // Host reference in fp64: first accumulate A*B ...
  thrust::host_vector<TC> h_c_ref(h_c.size(), static_cast<TC>(0.0));
  auto h_c_ref_tensor = make_tensor(h_c_ref.data(), gmem_c_layout_t{});
  for (int k = 0; k < size<1>(h_a_tensor); k++) {
    for (int m = 0; m < size<0>(h_a_tensor); m++) {
      for (int n = 0; n < size<0>(h_b_tensor); n++) {
        const auto a_value      = a_load_transform(h_a_tensor(m, k));
        const auto b_value      = b_load_transform(h_b_tensor(n, k));
        const auto a_value_fp64 = static_cast<double>(a_value);
        const auto b_value_fp64 = static_cast<double>(b_value);
        h_c_ref_tensor(m, n) += static_cast<TC>(a_value_fp64 * b_value_fp64);
      }
    }
  }
  // ... then apply the epilogue C = store(alpha * (A*B) + beta * load(C)).
  for (int i = 0; i < size(h_c_ref_tensor); i++) {
    const auto ab_value_fp64 = static_cast<double>(h_c_ref_tensor(i));
    const auto c_value_fp64  = static_cast<double>(c_load_transform(h_c_tensor(i)));
    h_c_ref_tensor(i) = c_store_transform(static_cast<TC>(alpha * ab_value_fp64 + beta * c_value_fp64));
  }

  // Compare device result against the reference with a loose relative
  // tolerance (low-precision operand types accumulate real rounding error).
  h_c_out = d_c_out;
  auto h_c_out_tensor = make_tensor(h_c_out.data(), gmem_c_layout_t{});
  for (int i = 0; i < size(h_c_ref_tensor); i++) {
    double h_c_ref_i = h_c_ref_tensor(i);
    double h_c_out_i = h_c_out_tensor(i);
    double epsilon(0.1f);
    double nonzero_floor(std::numeric_limits<double>::min());
    bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor);
    ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i;
  }
}
// Convenience wrapper for mixed operand types: column-major A (M,K) and
// C (M,N), K-major B (N,K); shared-memory tiles mirror the global layouts
// exactly, and all three operands use auto-vectorizing smem copy ops.
template<uint32_t M,
         uint32_t N,
         uint32_t K,
         uint32_t ThreadBlockSize,
         class TiledMMAType,
         uint32_t CopyMaxVecBits,
         class TA,
         class TB,
         class TC,
         class ALoadTransform  = cute::identity,
         class BLoadTransform  = cute::identity,
         class CLoadTransform  = cute::identity,
         class CStoreTransform = cute::identity>
void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {},
                                            BLoadTransform const& b_load_transform = {},
                                            CLoadTransform const& c_load_transform = {},
                                            CStoreTransform const& c_store_transform = {})
{
  // One layout type per operand; the same type serves for gmem and smem
  // since the smem tile is an exact copy of the gmem tile here.
  using layout_a_t = decltype(make_layout(make_shape(Int<M>{}, Int<K>{})));
  using layout_b_t = decltype(make_layout(make_shape(Int<N>{}, Int<K>{}), GenRowMajor{}));
  using layout_c_t = decltype(make_layout(make_shape(Int<M>{}, Int<N>{})));

  test_cooperative_gemm<layout_a_t, layout_b_t, layout_c_t,
                        layout_a_t, layout_b_t, layout_c_t,
                        AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<TA>>,
                        AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<TB>>,
                        AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<TC>>,
                        ThreadBlockSize,
                        TiledMMAType,
                        CopyMaxVecBits,
                        TA, TB, TC>(
      a_load_transform, b_load_transform, c_load_transform, c_store_transform);
}
// Single-type shorthand: runs the mixed-type overload with TA = TB = TC = T
// and the maximum copy vector width set to one element of T.
template<uint32_t M,
         uint32_t N,
         uint32_t K,
         uint32_t ThreadBlockSize,
         class TiledMMAType,
         class T,
         class ALoadTransform  = cute::identity,
         class BLoadTransform  = cute::identity,
         class CLoadTransform  = cute::identity,
         class CStoreTransform = cute::identity>
void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {},
                                            BLoadTransform const& b_load_transform = {},
                                            CLoadTransform const& c_load_transform = {},
                                            CStoreTransform const& c_store_transform = {})
{
  constexpr uint32_t max_vec_bits = cute::sizeof_bits_v<T>;
  test_cooperative_gemm_col_major_layout<M, N, K, ThreadBlockSize, TiledMMAType,
                                         max_vec_bits, T, T, T>(
      a_load_transform, b_load_transform, c_load_transform, c_store_transform);
}
// Wrapper taking per-operand shared-memory *atom* layouts: each atom is tiled
// out to the full (M,K)/(N,K)/(M,N) problem shape with tile_to_shape, while
// the global-memory layouts stay column-major A/C and K-major B.
template<class SMemAAtomLayout,
         class SMemBAtomLayout,
         class SMemCAtomLayout,
         uint32_t M,
         uint32_t N,
         uint32_t K,
         uint32_t ThreadBlockSize,
         class TiledMMAType,
         uint32_t CopyMaxVecBits,
         class TA,
         class TB,
         class TC,
         class ALoadTransform  = cute::identity,
         class BLoadTransform  = cute::identity,
         class CLoadTransform  = cute::identity,
         class CStoreTransform = cute::identity>
void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {},
                                            BLoadTransform const& b_load_transform = {},
                                            CLoadTransform const& c_load_transform = {},
                                            CStoreTransform const& c_store_transform = {})
{
  using gmem_a_t = decltype(make_layout(make_shape(Int<M>{}, Int<K>{})));
  using gmem_b_t = decltype(make_layout(make_shape(Int<N>{}, Int<K>{}), GenRowMajor{}));
  using gmem_c_t = decltype(make_layout(make_shape(Int<M>{}, Int<N>{})));

  // Replicate each smem atom across the full operand shape.
  using smem_a_t = decltype(tile_to_shape(
      SMemAAtomLayout{},
      make_shape(shape<0>(gmem_a_t{}), shape<1>(gmem_a_t{}))));
  using smem_b_t = decltype(tile_to_shape(
      SMemBAtomLayout{},
      make_shape(shape<0>(gmem_b_t{}), shape<1>(gmem_b_t{}))));
  using smem_c_t = decltype(tile_to_shape(
      SMemCAtomLayout{},
      make_shape(shape<0>(gmem_c_t{}), shape<1>(gmem_c_t{}))));

  test_cooperative_gemm<gmem_a_t, gmem_b_t, gmem_c_t,
                        smem_a_t, smem_b_t, smem_c_t,
                        AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<TA>>,
                        AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<TB>>,
                        AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<TC>>,
                        ThreadBlockSize,
                        TiledMMAType,
                        CopyMaxVecBits,
                        TA, TB, TC>(
      a_load_transform, b_load_transform, c_load_transform, c_store_transform);
}
// Single-type shorthand for the atom-layout overload: TA = TB = TC = T with
// the maximum copy vector width set to one element of T.
template<class SMemAAtomLayout,
         class SMemBAtomLayout,
         class SMemCAtomLayout,
         uint32_t M,
         uint32_t N,
         uint32_t K,
         uint32_t ThreadBlockSize,
         class TiledMMAType,
         class T,
         class ALoadTransform  = cute::identity,
         class BLoadTransform  = cute::identity,
         class CLoadTransform  = cute::identity,
         class CStoreTransform = cute::identity>
void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {},
                                            BLoadTransform const& b_load_transform = {},
                                            CLoadTransform const& c_load_transform = {},
                                            CStoreTransform const& c_store_transform = {})
{
  constexpr uint32_t max_vec_bits = cute::sizeof_bits_v<T>;
  test_cooperative_gemm_col_major_layout<SMemAAtomLayout, SMemBAtomLayout, SMemCAtomLayout,
                                         M, N, K, ThreadBlockSize, TiledMMAType,
                                         max_vec_bits, T, T, T>(
      a_load_transform, b_load_transform, c_load_transform, c_store_transform);
}