426 lines
12 KiB
Plaintext
426 lines
12 KiB
Plaintext
/***************************************************************************************************
|
|
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
|
|
/*
|
|
This example requires NVIDIA Maxwell GPU or beyond.
|
|
*/
|
|
|
|
// Standard Library includes
|
|
#include <iostream>
|
|
#include <sstream>
|
|
#include <vector>
|
|
|
|
// CUTLASS Includes
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/core_io.h"
|
|
#include "cutlass/functional.h"
|
|
#include "cutlass/layout/matrix.h"
|
|
#include "cutlass/gemm/warp/mma_simt.h"
|
|
#include "cutlass/epilogue/warp/fragment_iterator_simt.h"
|
|
#include "cutlass/epilogue/warp/tile_iterator_simt.h"
|
|
|
|
// CUTLASS Utility Includes
|
|
#include "cutlass/util/host_tensor.h"
|
|
#include "cutlass/util/tensor_view_io.h"
|
|
|
|
#include "cutlass/util/reference/host/gemm.h"
|
|
#include "cutlass/util/reference/host/tensor_compare.h"
|
|
#include "cutlass/util/reference/host/tensor_fill.h"
|
|
#include "cutlass/util/reference/host/tensor_copy.h"
|
|
#include "cutlass/util/reference/host/gemm_complex.h"
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Define the overal warp-level problem shape
|
|
int const kM = 14;
|
|
int const kN = 27;
|
|
int const kK = 17;
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Define a warp-level GEMM operator.
|
|
//
|
|
// This template could be part of the CUTLASS Template Library or implemented internally. This
|
|
// wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be
|
|
// instantiated in device code.
|
|
|
|
namespace cutlass {
|
|
namespace gemm {
|
|
namespace warp {
|
|
|
|
template <
|
|
typename Shape,
|
|
typename ElementA,
|
|
typename LayoutA,
|
|
typename ElementB,
|
|
typename LayoutB,
|
|
typename ElementC,
|
|
typename LayoutC,
|
|
typename ElementScalar
|
|
>
|
|
class GemmSimt {
|
|
public:
|
|
|
|
|
|
using Policy = cutlass::gemm::warp::MmaSimtPolicy<
|
|
cutlass::MatrixShape<4, 8>,
|
|
cutlass::layout::RowMajorInterleaved<2>,
|
|
cutlass::gemm::GemmShape<4, 4, 1>
|
|
>;
|
|
|
|
using MmaWarp = cutlass::gemm::warp::MmaSimt<
|
|
cutlass::gemm::GemmShape<16, 32, 8>,
|
|
float,
|
|
cutlass::layout::RowMajor,
|
|
float,
|
|
cutlass::layout::ColumnMajor,
|
|
float,
|
|
cutlass::layout::RowMajor,
|
|
Policy
|
|
>;
|
|
|
|
// Number of 'K groups'
|
|
int const kKgroups = Shape::kK;
|
|
|
|
using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt<
|
|
typename MmaWarp::Shape,
|
|
typename MmaWarp::ThreadMma,
|
|
layout::RowMajor, // SMEM layout
|
|
typename MmaWarp::Policy
|
|
>;
|
|
|
|
using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorSimtCanonical<
|
|
typename MmaWarp::Shape,
|
|
typename MmaWarp::ThreadMma,
|
|
float, // ElementAccumulator
|
|
layout::RowMajor, // SMEM layout
|
|
typename MmaWarp::Policy
|
|
>;
|
|
|
|
using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
|
|
using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
|
|
using TensorRefC = typename AccumulatorTileIterator::TensorRef;
|
|
|
|
public:
|
|
CUTLASS_HOST_DEVICE
|
|
GemmSimt() { }
|
|
|
|
CUTLASS_DEVICE
|
|
void operator()(
|
|
ElementScalar alpha,
|
|
TensorRefA ref_A,
|
|
TensorRefB ref_B,
|
|
ElementScalar beta,
|
|
TensorRefC ref_C,
|
|
TensorRefC ref_D,
|
|
int lane_id) const {
|
|
|
|
// Instantiate iterators pointing to slices of the A and B matrices in shared memory
|
|
typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id);
|
|
typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id);
|
|
|
|
// Instantiate and clear accumulator tile holding the C matrix
|
|
typename MmaWarp::FragmentC accum;
|
|
accum.clear();
|
|
|
|
// Instantiate the warp-level matrix multiply operator
|
|
MmaWarp mma_op;
|
|
|
|
// Instantiate fragments holding the slice of the matrix held by each warp
|
|
typename MmaWarp::FragmentA frag_A[2];
|
|
typename MmaWarp::FragmentB frag_B[2];
|
|
|
|
// Load fragments from shared memory
|
|
iter_A.load(frag_A[0]);
|
|
iter_B.load(frag_B[0]);
|
|
|
|
++iter_A;
|
|
++iter_B;
|
|
|
|
// Load fragments from shared memory
|
|
CUTLASS_PRAGMA_UNROLL
|
|
for (int k = 0; k < kKgroups; ++k) {
|
|
|
|
// Load fragments from shared memory
|
|
iter_A.load(frag_A[(k + 1) % 2]);
|
|
iter_B.load(frag_B[(k + 1) % 2]);
|
|
|
|
++iter_A;
|
|
++iter_B;
|
|
|
|
// Compute the matrix multiply
|
|
mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum);
|
|
}
|
|
|
|
// Instantiate iterators
|
|
FragmentIterator accum_frag_it(accum);
|
|
AccumulatorTileIterator source_tile_it(ref_C, {Shape::kM, Shape::kN}, lane_id);
|
|
AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id);
|
|
|
|
// Define function objects for linear scaling operation
|
|
cutlass::multiplies<typename FragmentIterator::Fragment> mul_source;
|
|
cutlass::multiply_add<typename FragmentIterator::Fragment> mul_add_accumulator;
|
|
|
|
// Iterate over the epilogue components
|
|
CUTLASS_PRAGMA_UNROLL
|
|
for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) {
|
|
|
|
// Define storage for slices of the accumulators
|
|
typename FragmentIterator::Fragment accum_fragment;
|
|
typename FragmentIterator::Fragment source_fragment;
|
|
|
|
// Select a slice of accumulators from the accumulator tile
|
|
accum_frag_it.load(accum_fragment);
|
|
++accum_frag_it;
|
|
|
|
// Load a corresponding slice from Shared memory
|
|
source_tile_it.load(source_fragment);
|
|
++source_tile_it;
|
|
|
|
// Compute linear scaling - alpha * AB + beta * C
|
|
source_fragment = mul_source(beta, source_fragment);
|
|
accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment);
|
|
|
|
// Store the result to shared memory
|
|
dest_tile_it.store(accum_fragment);
|
|
++dest_tile_it;
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace warp
|
|
} // namespace gemm
|
|
} // namespace cutlass
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held
|
|
// in Shared Memory.
|
|
__global__ void kernel(
|
|
float *D_gmem,
|
|
float alpha,
|
|
float const *A_gmem,
|
|
float const *B_gmem,
|
|
float beta,
|
|
float const *C_gmem) {
|
|
|
|
// Define several matrices in shared memory
|
|
__shared__ float A[kM][kK];
|
|
__shared__ float B[kN][kK];
|
|
__shared__ float C[kM][kN];
|
|
|
|
// Copy data into SMEM
|
|
if (threadIdx.x == 0) {
|
|
CUTLASS_PRAGMA_NO_UNROLL
|
|
for (int m = 0; m < kM; ++m) {
|
|
for (int k = 0; k < kK; ++k) {
|
|
A[m][k] = A_gmem[m * kK + k];
|
|
}
|
|
}
|
|
CUTLASS_PRAGMA_NO_UNROLL
|
|
for (int n = 0; n < kN; ++n) {
|
|
for (int k = 0; k < kK; ++k) {
|
|
B[n][k] = B_gmem[n * kK + k];
|
|
}
|
|
}
|
|
CUTLASS_PRAGMA_NO_UNROLL
|
|
for (int m = 0; m < kM; ++m) {
|
|
CUTLASS_PRAGMA_NO_UNROLL
|
|
for (int n = 0; n < kN; ++n) {
|
|
C[m][n] = C_gmem[m * kN + n];
|
|
}
|
|
}
|
|
}
|
|
|
|
__syncthreads();
|
|
|
|
//
|
|
// Instantiate a warp-level matrix multiply operator given the fundamental instruction shape (8x8x4),
|
|
// overall shape, data type of each operand, and layout of each operand.
|
|
//
|
|
|
|
using GemmSimt = cutlass::gemm::warp::GemmSimt<
|
|
cutlass::gemm::GemmShape<kM, kN, kK>,
|
|
float, // Data type of A elements
|
|
cutlass::layout::RowMajor, // Layout of A matrix
|
|
float, // Data type of B elements
|
|
cutlass::layout::ColumnMajor, // Layout of B matrix
|
|
float, // Data type of C elements
|
|
cutlass::layout::RowMajor, // Layout of C matrix
|
|
float // Scalar type of alpha and beta
|
|
>;
|
|
|
|
// Instantiate the GEMM operator
|
|
GemmSimt gemm;
|
|
|
|
// Execute the warp-level GEMM operation
|
|
gemm(
|
|
alpha,
|
|
{&A[0][0], kK},
|
|
{&B[0][0], kK},
|
|
beta,
|
|
{&C[0][0], kN},
|
|
{&C[0][0], kN},
|
|
threadIdx.x);
|
|
|
|
__syncthreads();
|
|
|
|
// Copy data into SMEM
|
|
if (threadIdx.x == 0) {
|
|
CUTLASS_PRAGMA_NO_UNROLL
|
|
for (int m = 0; m < kM; ++m) {
|
|
CUTLASS_PRAGMA_NO_UNROLL
|
|
for (int n = 0; n < kN; ++n) {
|
|
D_gmem[m * kN + n] = C[m][n];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
int main(int argc, const char *arg[]) {
|
|
|
|
cutlass::HostTensor<float, cutlass::layout::RowMajor> A({kM, kK});
|
|
cutlass::HostTensor<float, cutlass::layout::ColumnMajor> B({kK, kN});
|
|
cutlass::HostTensor<float, cutlass::layout::RowMajor> C({kM, kN});
|
|
cutlass::HostTensor<float, cutlass::layout::RowMajor> D({kM, kN});
|
|
|
|
uint64_t seed = 2020;
|
|
float max = 8;
|
|
float min = -8;
|
|
|
|
std::cout << "Simt canonical GEMM problem size = (" << cutlass::gemm::GemmShape<kM, kN, kK>() <<")" << std::endl;
|
|
|
|
cutlass::reference::host::TensorFillRandomUniform(
|
|
A.host_view(),
|
|
seed,
|
|
max,
|
|
min,
|
|
0
|
|
);
|
|
|
|
cutlass::reference::host::TensorFillRandomUniform(
|
|
B.host_view(),
|
|
seed + 17,
|
|
max,
|
|
min,
|
|
0
|
|
);
|
|
|
|
#if 0 // Debug: fill A sequentially and B as Identity matrix for debugging
|
|
cutlass::reference::host::BlockFillSequential(
|
|
A.host_view().data(), A.host_view().capacity());
|
|
|
|
cutlass::reference::host::TensorFillIdentity(B.host_view());
|
|
#endif
|
|
|
|
cutlass::reference::host::TensorFillRandomUniform(
|
|
C.host_view(),
|
|
seed + 31,
|
|
max,
|
|
min,
|
|
0
|
|
);
|
|
|
|
A.sync_device();
|
|
B.sync_device();
|
|
C.sync_device();
|
|
D.sync_device();
|
|
|
|
dim3 grid(1, 1);
|
|
dim3 block(32, 1, 1);
|
|
|
|
float alpha = 1.0f;
|
|
float beta = 0.0f;
|
|
|
|
kernel<<< grid, block >>>(
|
|
D.device_data(),
|
|
alpha,
|
|
A.device_data(),
|
|
B.device_data(),
|
|
beta,
|
|
C.device_data()
|
|
);
|
|
|
|
cudaError_t result = cudaDeviceSynchronize();
|
|
if (result != cudaSuccess) {
|
|
std::cerr << "Failed to synchronize device after kernel launch." << std::endl;
|
|
return -1;
|
|
}
|
|
|
|
D.sync_host();
|
|
|
|
// Compute reference on host
|
|
cutlass::HostTensor<float, cutlass::layout::RowMajor> D_ref({kM, kN}, false);
|
|
cutlass::reference::host::TensorCopy(D_ref.host_view(), C.host_view());
|
|
|
|
cutlass::reference::host::Gemm<
|
|
float, cutlass::layout::RowMajor,
|
|
float, cutlass::layout::ColumnMajor,
|
|
float, cutlass::layout::RowMajor,
|
|
float, float> reference_gemm;
|
|
|
|
reference_gemm(
|
|
{kM, kN, kK},
|
|
alpha,
|
|
A.host_ref(),
|
|
B.host_ref(),
|
|
beta,
|
|
D_ref.host_ref(),
|
|
float()
|
|
);
|
|
|
|
// Verify reference matches computed
|
|
if (!cutlass::reference::host::TensorEquals(
|
|
D.host_view(),
|
|
D_ref.host_view())) {
|
|
|
|
std::cerr
|
|
<< "A =\n" << A.host_view()
|
|
<< "\n\nB = \n" << B.host_view()
|
|
<< "\n\nC = " << C.host_view()
|
|
<< "\n\nRef =\n" << D_ref.host_view()
|
|
<< "\n\nD =\n" << D.host_view() << "\n\n";
|
|
|
|
std::cerr << "Error - device results mismatch host reference." << std::endl;
|
|
|
|
return -1;
|
|
}
|
|
|
|
std::cout << "Passed" << std::endl;
|
|
|
|
return 0;
|
|
|
|
}
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|