// cutlass/examples/cute/tutorial/sgemm_2.cu
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <cute/tensor.hpp>
#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class TiledCopyA,
          class TB, class BStride, class BSmemLayout, class TiledCopyB,
          class TC, class CStride, class CSmemLayout, class TiledMma,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
            TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
            TC      * C, CStride dC, CSmemLayout          , TiledMma mma,
            Alpha alpha, Beta beta)
{
using namespace cute;
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
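// TUTORIAL (note): congruent checks that two IntTuples have the same
// hierarchical profile; e.g. select<0,2>(shape_MNK) == (M,K) is congruent
// with a flat rank-2 stride such as (Int<1>{}, ldA).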
//
// Full and Tiled Tensors
//
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
// Get the appropriate blocks for this thread block
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
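// TUTORIAL (sketch): the Step arguments project the rank-3 tiler/coord onto
// each tensor. For example,
//   local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{})
// keeps modes 0 and 2 and is equivalent to
//   local_tile(mA, select<0,2>(cta_tiler), select<0,2>(cta_coord));
// the _ in the k-coordinate keeps every K-tile, producing the trailing (k)
// mode of gA and gB.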
// Shared memory buffers
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)
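// TUTORIAL (note): cosize_v<Layout> is the size of the layout's codomain,
// i.e. one past the largest smem index it can produce. For a padded layout
// such as the TN path's (128,8):(1,129), that is 1031 elements -- slightly
// more than the 1024 values actually stored.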
//
// Partition the copying of A and B tiles across the threads
//
// TUTORIAL: Example of partitioning via a TiledCopy
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K)
// Allocate registers same shape/layout as partitioned data
Tensor tArA = make_fragment_like(tAsA); // (CPY,CPY_M,CPY_K)
ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k)
Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K)
// Allocate registers same shape/layout as partitioned data
Tensor tBrB = make_fragment_like(tBsB); // (CPY,CPY_N,CPY_K)
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tArA)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tArA)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBrB)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K
CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBrB)); // CPY_K
// Copy gmem to rmem for k_tile=0
copy(copy_a, tAgA(_,_,_,0), tArA);
copy(copy_b, tBgB(_,_,_,0), tBrB);
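// TUTORIAL (note): these k_tile=0 loads are the pipeline prologue. The main
// loop below overlaps the gmem->rmem loads of tile k+1 with the smem stores
// and MMAs of tile k.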
//
// Define A/B partitioning and C accumulators
//
// TUTORIAL: Example of partitioning via a TiledMMA
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N)
CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K
// Clear the accumulators
clear(tCrC);
#if 0
if(thread0()) {
print(" mA : "); print( mA); print("\n");
print(" gA : "); print( gA); print("\n");
print(" sA : "); print( sA); print("\n");
print("tAgA : "); print(tAgA); print("\n");
print("tAsA : "); print(tAsA); print("\n");
print("tArA : "); print(tArA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mB : "); print( mB); print("\n");
print(" gB : "); print( gB); print("\n");
print(" sB : "); print( sB); print("\n");
print("tBgB : "); print(tBgB); print("\n");
print("tBsB : "); print(tBsB); print("\n");
print("tArA : "); print(tArA); print("\n");
}
#endif
#if 0
if(thread0()) {
print(" mC : "); print( mC); print("\n");
print(" gC : "); print( gC); print("\n");
print("tCsA : "); print(tCsA); print("\n");
print("tCsB : "); print(tCsB); print("\n");
print("tCgC : "); print(tCgC); print("\n");
print("tCrC : "); print(tCrC); print("\n");
}
#endif
#if 1
// TUTORIAL: Example of an inner loop that pipelines compute with reads
// from global memory by staging through registers and shared memory.
// Data is read from global memory into registers, then stored to shared
// memory via the TiledCopy partitions; gemm(.) then operates directly on
// shared memory via the TiledMMA partitions.
auto K_TILE_MAX = size<3>(tAgA);
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
// Copy rmem to smem with tA|tB thread-partitioned tensors
__syncthreads(); // Wait for all threads to consume smem
copy(tArA, tAsA);
copy(tBrB, tBsB);
__syncthreads(); // Wait for all threads to load smem
// Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors
int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? k_tile + 1 : k_tile;
copy(copy_a, tAgA(_,_,_,k_tile_next), tArA);
copy(copy_b, tBgB(_,_,_,k_tile_next), tBrB);
// TUTORIAL: The above call to copy(copy_a, tAgA(_,_,_,k_tile_next), tArA) is equivalent to
//   CUTE_UNROLL
//   for (int k = 0; k < size<2>(tArA); ++k) {
//     CUTE_UNROLL
//     for (int m = 0; m < size<1>(tArA); ++m) {
//       copy_a.call(tAgA(_,m,k,k_tile_next), tArA(_,m,k));
//     }
//   }
// Compute gemm on mma-partitioned smem
gemm(mma, tCsA, tCsB, tCrC);
// TUTORIAL: The above call to gemm(mma, tCsA, tCsB, tCrC) is equivalent to
//   CUTE_UNROLL
//   for (int k = 0; k < size<2>(tCsA); ++k) {
//     CUTE_UNROLL
//     for (int m = 0; m < size<1>(tCrC); ++m) {
//       CUTE_UNROLL
//       for (int n = 0; n < size<2>(tCrC); ++n) {
//         mma.call(tCsA(_,m,k), tCsB(_,n,k), tCrC(_,m,n));
//       }
//     }
//   }
}
#endif
//
// Epilogue
//
axpby(alpha, tCrC, beta, tCgC);
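// TUTORIAL (note): axpby performs, elementwise over this thread's partition of C,
//   tCgC(i) = alpha * tCrC(i) + beta * tCgC(i)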
}
// Setup params for a NT GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
// TUTORIAL: Construct TiledCopy with a particular Copy_Atom to use and
// define the partitioning pattern to apply.
// Each thread will (try to) copy 4x1 elements of type TA using 128-bit copy.
// Use 32x8 of these threads.
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TA>{},
                                  Layout<Shape<_32,_8>>{},  // Thr layout 32x8 m-major
                                  Layout<Shape< _4,_1>>{}); // Val layout 4x1 m-major
TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TB>{},
                                  Layout<Shape<_32,_8>>{},  // Thr layout 32x8 n-major
                                  Layout<Shape< _4,_1>>{}); // Val layout 4x1 n-major
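// TUTORIAL (sketch, given the layouts above): the 32x8 thread layout with 4x1
// values tiles a (32*4)x(8*1) = 128x8 block, so a single copy covers a full
// (BLK_M,BLK_K) = (128,8) tile of A (likewise (BLK_N,BLK_K) for B) with
// exactly one 128-bit load per thread.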
// TUTORIAL: Construct TiledMMA with a particular MMA_Atom to use and
// define the partitioning pattern to apply.
// Use a 1x1x1 FMA on the types TC += TA * TB. Each atom requires a single thread.
// Reproduce that atom 16x16x1 times (m-major) across threads so that we use 256 threads.
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 UniversalFMA
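// TUTORIAL (sketch): 256 threads arranged 16x16 over the 128x128 C tile give
// each thread a (128/16)x(128/16) = 8x8 fragment of accumulators, i.e. tCrC
// holds 64 values of type TC per thread.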
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
auto sB = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
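// TUTORIAL (note): the +1 pad makes the leading dimension 129, which is
// coprime with the 32 smem banks. Without it, the k-major stores below would
// land a warp's 32 writes in only four banks (index m + 128*k maps to bank
// m % 32 with m in 0..3); with the pad they spread as (m + k) % 32.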
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx
// TUTORIAL: Construct TiledCopy to define the Copy_Atom to use and the
// partitioning pattern to apply.
// Each thread will copy 1x1 elements of type TA.
// Use 32x8 of these threads arranged in k-major.
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<TA>, TA>{},
                                  Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                  Layout<Shape< _1,_1>>{});              // Val layout 1x1
TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<TB>, TB>{},
                                  Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                  Layout<Shape< _1,_1>>{});              // Val layout 1x1
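// TUTORIAL (note): for TN, A is k-major in gmem (stride 1 along K), so the
// k-major thread layout makes consecutive threads read consecutive addresses
// (coalesced). Since m is not contiguous in gmem, the 128-bit vectorization
// used in gemm_nt does not apply and each thread copies scalar 1x1 values.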
// TUTORIAL: Construct TiledMMA to define the MMA_Atom to use and the
// partitioning pattern to apply.
// Use a 1x1x1 FMA on the types TC += TA * TB. Each atom requires a single thread.
// Reproduce that atom 16x16x1 times (m-major) across threads so that we use 256 threads.
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA
#if 0
print(copyA);
print(copyB);
print(mmaC);
#endif
#if 0
print_latex(copyA);
print_latex(copyB);
print_latex(mmaC);
#endif
dim3 dimBlock(size(mmaC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, copyA,
B, dB, sB, copyB,
C, dC, sC, mmaC,
alpha, beta);
}
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
Alpha alpha,
TA const* A, int ldA,
TB const* B, int ldB,
Beta beta,
TC * C, int ldC,
cudaStream_t stream = 0)
{
if (transA == 'N' && transB == 'T') {
return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
} else
if (transA == 'T' && transB == 'N') {
return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
}
assert(false && "Not implemented");
}
int main(int argc, char** argv)
{
int m = 5120;
if (argc >= 2)
sscanf(argv[1], "%d", &m);
int n = 5120;
if (argc >= 3)
sscanf(argv[2], "%d", &n);
int k = 4096;
if (argc >= 4)
sscanf(argv[3], "%d", &k);
char transA = 'N';
if (argc >= 5)
sscanf(argv[4], "%c", &transA);
char transB = 'T';
if (argc >= 6)
sscanf(argv[5], "%c", &transB);
using TA = float;
using TB = float;
using TC = float;
using TI = float;
TI alpha = 1.0;
TI beta = 0.0;
std::cout << "M = " << m << std::endl;
std::cout << "N = " << n << std::endl;
std::cout << "K = " << k << std::endl;
std::cout << "C = A^" << transA << " B^" << transB << std::endl;
cute::device_init(0);
thrust::host_vector<TA> h_A(m*k);
thrust::host_vector<TB> h_B(n*k);
thrust::host_vector<TC> h_C(m*n);
for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);
thrust::device_vector<TA> d_A = h_A;
thrust::device_vector<TB> d_B = h_B;
thrust::device_vector<TC> d_C = h_C;
double gflops = (2.0*m*n*k) * 1e-9;
const int timing_iterations = 100;
GPU_Clock timer;
int ldA = 0, ldB = 0, ldC = m;
if (transA == 'N') {
ldA = m;
} else if (transA == 'T') {
ldA = k;
} else {
assert(false);
}
if (transB == 'N') {
ldB = k;
} else if (transB == 'T') {
ldB = n;
} else {
assert(false);
}
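// Leading dimensions above match the strides in gemm_nt/gemm_tn:
//   transA == 'N': dA = (1, ldA), A is m-major, so ldA = m
//   transA == 'T': dA = (ldA, 1), A is k-major, so ldA = k
//   transB == 'T': dB = (1, ldB), B is n-major, so ldB = n
//   transB == 'N': dB = (ldB, 1), B is k-major, so ldB = k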
// Run once
d_C = h_C;
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
CUTE_CHECK_LAST();
thrust::host_vector<TC> cute_result = d_C;
// Timing iterations
timer.start();
for (int i = 0; i < timing_iterations; ++i) {
gemm(transA, transB, m, n, k,
alpha,
d_A.data().get(), ldA,
d_B.data().get(), ldB,
beta,
d_C.data().get(), ldC);
}
double cute_time = timer.seconds() / timing_iterations;
CUTE_CHECK_LAST();
printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);
return 0;
}