Cutlass 1.3 Release (#42)
CUTLASS 1.3 Release - Efficient GEMM kernel targeting Volta Tensor Cores via mma.sync instruction added in CUDA 10.1.
This commit is contained in:
parent 19a9d64e3c
commit 877bdcace6
@@ -1,7 +1,7 @@
 # NVIDIA CUTLASS Changelog

-## [1.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.1) (2018-12-19)
-* Resolved issue with sm50 and sm52 architectures
+## [1.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.0) (2019-03-20)
+* Efficient GEMM kernel targeting Volta Tensor Cores via `mma.sync` instruction added in CUDA 10.1.

 ## [1.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.0) (2018-10-26)
 * Parallelized reductions across threadblocks ("Split-K")
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without modification, are permitted
 # provided that the following conditions are met:
@@ -20,7 +20,7 @@
 # STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-cmake_minimum_required(VERSION 3.3.0)
+cmake_minimum_required(VERSION 3.3.0 FATAL_ERROR)

 set(CUTLASS_LANGUAGES CXX)
@@ -36,7 +36,8 @@ else()
 # FindCUDA fails to detect VS 2017 due to a changed directory format of the toolkits.
 # For this configuration we need CMake >= 3.9.0 to use the native CUDA support.
 if (WIN32 AND MSVC_VERSION GREATER 1800)
-  message(FATAL_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher")
+  message(SEND_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher")
+  cmake_minimum_required(VERSION 3.9.0 FATAL_ERROR)
 endif()

 # Fall back to the FindCUDA version to create an executable with CUDA files
@@ -52,7 +53,11 @@ if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
   message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
 endif()

-find_package(CUDA)
+find_package(CUDA REQUIRED)
+include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
+# Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include
+# paths by default, so we add it explicitly here.

 find_package(Doxygen QUIET)

 ###################################################################################################
@@ -61,9 +66,18 @@ find_package(Doxygen QUIET)
 #
 ###################################################################################################

-find_library(CUBLAS_LIBRARY cublas HINTS
-  ${CUDA_TOOLKIT_ROOT_DIR}/lib64
-  ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
+#
+# Conditionally enable cuBLAS
+#
+set(CUTLASS_ENABLE_CUBLAS ON CACHE BOOL "Enable CUTLASS Tests to build with cuBLAS library.")
+
+if(CUTLASS_ENABLE_CUBLAS)
+
+  find_library(CUBLAS_LIBRARY cublas HINTS
+    ${CUDA_TOOLKIT_ROOT_DIR}/lib64
+    ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
+endif()

 # By default we want to build in Release mode to ensure that we're getting best performance
 if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
@@ -78,26 +92,56 @@ if(WIN32)
 endif()

 if (WIN32)
-# Enable more warnings and treat as errors
-string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
+  # Enable more warnings and treat as errors
+  string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")

-# Disable warning on Unicode characters
-string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")
+  # Disable warning on Unicode characters
+  string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")

-# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
-string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
+  # Disable excess x86 floating point precision that can lead to results being labeled incorrectly
+  string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")

-# Verbose option
-if (${CUTLASS_NVCC_VERBOSE})
-  string(APPEND NVCC_FLAGS " -v")
-endif()
+  # Verbose option
+  if (${CUTLASS_NVCC_VERBOSE})
+    string(APPEND NVCC_FLAGS " -v")
+  endif()
 endif(WIN32)

-set(CUTLASS_NVCC_ARCHS "50;60;61;70;75" CACHE STRING "The SM architectures to build code for.")
+set(CUTLASS_NVCC_ARCHS_DEFAULT "")
+if(NOT CUDA_VERSION VERSION_LESS 7.5)
+  list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 50)
+endif()
+if(NOT CUDA_VERSION VERSION_LESS 8.0)
+  list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 60 61)
+endif()
+if(NOT CUDA_VERSION VERSION_LESS 9.0)
+  list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 70)
+endif()
+if(NOT CUDA_VERSION VERSION_LESS 9.2)
+  list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 72)
+endif()
+if(NOT CUDA_VERSION VERSION_LESS 10.0)
+  list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 75)
+endif()
+set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_DEFAULT} CACHE STRING "The SM architectures to build code for.")

 set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
 set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
 set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")

+# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
+if (CUDA_VERSION VERSION_LESS 10.1)
+  set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT OFF)
+else()
+  set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON)
+endif()
+
+set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
+  "Enable PTX mma instruction for collective matrix multiply operations.")
+
 set(CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST ${CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST} CACHE BOOL
   "Enable more kernels instantiated in the perf suite. This might result in longer compiler time. ")

 #
 # NOTE: running with asan and CUDA requires the following environment variable:
 #
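In effect, the default architecture list now tracks the installed CUDA Toolkit rather than being hard-coded. It can still be overridden at configure time, e.g. `cmake .. -DCUTLASS_NVCC_ARCHS="70;75"` (an illustrative invocation, not taken from the commit), and `CUTLASS_ENABLE_TENSOR_CORE_MMA` follows the detected toolkit version the same way.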
@@ -131,6 +175,18 @@ foreach(ARCH ${CUTLASS_NVCC_ARCHS})
   endif()
 endforeach()

+if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
+  string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1")
+endif()
+
+if (CUTLASS_ENABLE_CUBLAS)
+  string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_CUBLAS=1")
+endif()
+
 if (CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST)
   add_definitions(-DEXHAUSTIVE_PROF)
 endif()
+
+if (CUTLASS_NVCC_KEEP)
+  string(APPEND NVCC_FLAGS " -keep")
+endif()
@@ -174,6 +230,7 @@ file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
 file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
 file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
 file(GLOB CUTLASS_REDUCTION RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/reduction/*.h )
+file(GLOB CUTLASS_LAYOUT_THREAD RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/layout/thread/*.h)

 ###################################################################################################
 #
@@ -185,16 +242,24 @@ source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
 source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
 source_group("cutlass\\device" FILES ${CUTLASS_DEVICE})
 source_group("cutlass\\reduction" FILES ${CUTLASS_REDUCTION})
+source_group("cutlass\\layout\\thread" FILES ${CUTLASS_LAYOUT_THREAD})
 source_group("cutlass" FILES ${CUTLASS_CORE})

 add_library(CUTLASS INTERFACE)
 include_directories("${CMAKE_CURRENT_SOURCE_DIR}")

+# Special policy introduced in CMake 3.13
+if (POLICY CMP0076)
+  cmake_policy(SET CMP0076 NEW)
+endif()
+
+target_sources(CUTLASS INTERFACE
+  ${CUTLASS_GEMM}
+  ${CUTLASS_UTIL}
+  ${CUTLASS_DEVICE}
+  ${CUTLASS_CORE}
+  ${CUTLASS_REDUCTION}
+  ${CUTLASS_LAYOUT_THREAD}
+)
+
+target_include_directories(CUTLASS INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
@@ -206,6 +271,7 @@ add_custom_target(cutlass_ide SOURCES
   ${CUTLASS_DEVICE}
   ${CUTLASS_CORE}
   ${CUTLASS_REDUCTION}
+  ${CUTLASS_LAYOUT_THREAD}
 )
 # Doxygen is available. Generate documentation
 if (DOXYGEN_FOUND)
@@ -14,7 +14,7 @@ CUTLASS core components, and to identify their role in implementing GEMM computa
 # <a name="S-design-patterns"></a> 1. Design Patterns

 CUTLASS strives to achieve the highest performance possible on NVIDIA GPUs while also offering a
-flexible composition that an be easily applied to solve new problems related to Deep Learning and
+flexible composition that can be easily applied to solve new problems related to Deep Learning and
 linear algebra. Though we intend to make CUTLASS as simple and straightforward as possible, given
 a tradeoff between simplicity and performance, CUTLASS chooses performance. Consequently, several
 design patterns are necessary to yield a composable structure while also satisfying these performance

@@ -31,7 +31,7 @@ CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvla
 ## <a name="S-patterns-tiles-iterators"></a> Tiles and Iterators

-Efficient dense linear algebra computations emphasize data movement to match the execution of mathemtical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants
+Efficient dense linear algebra computations emphasize data movement to match the execution of mathematical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants
 specifying element type, size, and data layout. CUTLASS refers to subpartitions as _tiles_.

 _Iterators_ are familiar design patterns in C++ that provide an abstraction for accessing individual
@@ -353,7 +353,7 @@ An example of splitK usage can be found [here](examples/06_splitK_gemm/splitK_ge
 # Copyright

-Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

 ```
 Redistribution and use in source and binary forms, with or without modification, are permitted
README.md (25 changes)
@@ -1,8 +1,8 @@
 ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

-# CUTLASS 1.2
+# CUTLASS 1.3

-_CUTLASS 1.2 - October 2018_
+_CUTLASS 1.3.0 - March 2019_

 CUTLASS is a collection of CUDA C++ template abstractions for implementing
 high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@@ -20,13 +20,18 @@ multiply-accumulate abstractions for 8-bit integer, half-precision floating
 point (FP16), single-precision floating point (FP32), and double-precision floating
 point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targeting
 the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
-and beyond.
+and beyond. Even faster performance on Volta is possible via direct access to
+Volta Tensor Cores via `mma.sync` (added in CUDA 10.1).

-CUTLASS 1.2 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
+CUTLASS 1.3 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
 [Doxygen documentation](https://nvidia.github.io/cutlass).
 We describe the structure of an efficient GEMM in our talk at the
 [GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).

+# What's New in CUTLASS 1.3
+_March 2019_
+* CUTLASS 1.3 includes an efficient GEMM implementation with the `mma.sync` instruction added in CUDA 10.1.
+
 # What's New in CUTLASS 1.2
 _October 2018_
 * [Parallelized Reductions](CUTLASS.md#parallel-reductions-across-gemm-k)
@@ -63,8 +68,8 @@ when compiled with CUDA 10.0.

 # Compatibility

-CUTLASS performs best when compiled with the [CUDA 10.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
-It is compatible with CUDA 9.0, 9.1, and 9.2, but these versions of the CUDA Toolkit do not support new Turing WMMA features.
+CUTLASS performs best when compiled with the [CUDA 10.1 Toolkit](https://developer.nvidia.com/cuda-toolkit).
+It is also compatible with CUDA 9.0, 9.1, 9.2, and 10.0.

 We have tested the following environments.
@@ -77,7 +82,7 @@ We have tested the following environments.
 | Ubuntu 18.04 | GCC 7.3.0 |

 CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
-any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU.
+any Maxwell-, Pascal-, Volta-, or Turing-architecture NVIDIA GPU.

 |**GPU**|
 |---|
@@ -220,6 +225,9 @@ Program usage:

 # Varies GEMM K dimension for SGEMM and IGEMM with column-major multiplicands
 $ ./tools/test/perf/cutlass_perf_test --m=10240 --n=4096 --k=1024:8192:128 --kernels=sgemm_nn,igemm_nn
+
+# Executes GEMM kernel on Volta Tensor Cores
+$ ./tools/test/perf/cutlass_perf_test --kernels=s884gemm_nt
 ```

 # About
@@ -230,7 +238,7 @@ CUTLASS is released by NVIDIA Corporation as Open Source software under the

 # Copyright

-Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

 ```
 Redistribution and use in source and binary forms, with or without modification, are permitted
@@ -253,4 +261,3 @@ Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 ```
@@ -0,0 +1,380 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Templates wrapping direct issue of MMA instructions to Tensor Cores.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/shape.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace arch {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Specifies internal data type for computation
+struct ComputeType {
+  enum Kind {
+    kBegin,
+    kDefault, /// Compute type implied by operand and accumulator types
+    kEnd
+  };
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Direct wrapper for native MMA instruction
+template <
+    /// Warp-level matrix multiply-accumulate operation
+    typename WmmaTile,
+    /// Layout of A multiplicand
+    MatrixLayout::Kind LayoutA,
+    /// Data type of A multiplicand
+    typename ScalarA,
+    /// Layout of B multiplicand
+    MatrixLayout::Kind LayoutB,
+    /// Data type of B multiplicand
+    typename ScalarB,
+    /// Data type of accumulators
+    typename ScalarC,
+    /// Specifies particular compute type, overriding data types of operands
+    ComputeType::Kind ComputeTy>
+inline __device__ void mma(ScalarA const A[], ScalarB const B[], ScalarC const C[], ScalarC D[]);
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+//
+// 16x16x4
+//
+
+//
+// FP16 accumulation
+//
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+                           MatrixLayout::kRowMajor,
+                           half,
+                           MatrixLayout::kColumnMajor,
+                           half,
+                           half,
+                           ComputeType::kDefault>(half const a[],
+                                                  half const b[],
+                                                  half const c[],
+                                                  half d[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+  unsigned const *A = reinterpret_cast<unsigned const *>(a);
+  unsigned const *B = reinterpret_cast<unsigned const *>(b);
+  unsigned const *C = reinterpret_cast<unsigned const *>(c);
+  unsigned *D = reinterpret_cast<unsigned *>(d);
+
+  asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
+               : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+               : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+  CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+                           MatrixLayout::kColumnMajor,
+                           half,
+                           MatrixLayout::kColumnMajor,
+                           half,
+                           half,
+                           ComputeType::kDefault>(half const a[],
+                                                  half const b[],
+                                                  half const c[],
+                                                  half d[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+  unsigned const *A = reinterpret_cast<unsigned const *>(a);
+  unsigned const *B = reinterpret_cast<unsigned const *>(b);
+  unsigned const *C = reinterpret_cast<unsigned const *>(c);
+  unsigned *D = reinterpret_cast<unsigned *>(d);
+
+  asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
+               : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+               : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+  CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+                           MatrixLayout::kRowMajor,
+                           half,
+                           MatrixLayout::kRowMajor,
+                           half,
+                           half,
+                           ComputeType::kDefault>(half const a[],
+                                                  half const b[],
+                                                  half const c[],
+                                                  half d[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+  unsigned const *A = reinterpret_cast<unsigned const *>(a);
+  unsigned const *B = reinterpret_cast<unsigned const *>(b);
+  unsigned const *C = reinterpret_cast<unsigned const *>(c);
+  unsigned *D = reinterpret_cast<unsigned *>(d);
+
+  asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
+               : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+               : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+  CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+                           MatrixLayout::kColumnMajor,
+                           half,
+                           MatrixLayout::kRowMajor,
+                           half,
+                           half,
+                           ComputeType::kDefault>(half const a[],
+                                                  half const b[],
+                                                  half const c[],
+                                                  half d[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+  unsigned const *A = reinterpret_cast<unsigned const *>(a);
+  unsigned const *B = reinterpret_cast<unsigned const *>(b);
+  unsigned const *C = reinterpret_cast<unsigned const *>(c);
+  unsigned *D = reinterpret_cast<unsigned *>(d);
+
+  asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
+               : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+               : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+  CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+//
+// FP32 accumulation
+//
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+                           MatrixLayout::kRowMajor,
+                           half,
+                           MatrixLayout::kColumnMajor,
+                           half,
+                           float,
+                           ComputeType::kDefault>(half const a[],
+                                                  half const b[],
+                                                  float const C[],
+                                                  float D[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+  unsigned const *A = reinterpret_cast<unsigned const *>(a);
+  unsigned const *B = reinterpret_cast<unsigned const *>(b);
+
+  asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
+               "{%12,%13,%14,%15,%16,%17,%18,%19};"
+               : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]),
+                 "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
+               : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
+                 "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
+                 "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
+
+#else
+  CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+                           MatrixLayout::kColumnMajor,
+                           half,
+                           MatrixLayout::kColumnMajor,
+                           half,
+                           float,
+                           ComputeType::kDefault>(half const a[],
+                                                  half const b[],
+                                                  float const C[],
+                                                  float D[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+  unsigned const *A = reinterpret_cast<unsigned const *>(a);
+  unsigned const *B = reinterpret_cast<unsigned const *>(b);
+
+  asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
+               "{%12,%13,%14,%15,%16,%17,%18,%19};"
+               : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]),
+                 "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
+               : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
+                 "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
+                 "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
+
+#else
+  CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+                           MatrixLayout::kRowMajor,
+                           half,
+                           MatrixLayout::kRowMajor,
+                           half,
+                           float,
+                           ComputeType::kDefault>(half const a[],
+                                                  half const b[],
+                                                  float const C[],
+                                                  float D[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+  unsigned const *A = reinterpret_cast<unsigned const *>(a);
+  unsigned const *B = reinterpret_cast<unsigned const *>(b);
+
+  asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
+               "{%12,%13,%14,%15,%16,%17,%18,%19};"
+               : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]),
+                 "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
+               : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
+                 "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
+                 "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
+
+#else
+  CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+/// Volta mma.sync instruction
+template <>
+inline __device__ void mma<Shape<4, 16, 16>,
+                           MatrixLayout::kColumnMajor,
+                           half,
+                           MatrixLayout::kRowMajor,
+                           half,
+                           float,
+                           ComputeType::kDefault>(half const a[],
+                                                  half const b[],
+                                                  float const C[],
+                                                  float D[]) {
+#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
+
+  unsigned const *A = reinterpret_cast<unsigned const *>(a);
+  unsigned const *B = reinterpret_cast<unsigned const *>(b);
+
+  asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
+               "{%12,%13,%14,%15,%16,%17,%18,%19};"
+               : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]),
+                 "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
+               : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
+                 "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
+                 "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
+
+#else
+  CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
+#endif
+}
+
+} // namespace arch
+} // namespace cutlass
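For orientation, here is a minimal sketch (not from the commit) of invoking one of the wrappers above from device code. Assumptions are hedged: the header path `cutlass/arch/mma.h`, the kernel name, and the naive fragment loads are all illustrative; the real per-thread operand layout of `mma.sync` is more intricate and is handled by CUTLASS's iterators. The fragment sizes follow the inline PTX above: two packed 32-bit registers (4 halves) for each of A and B, four registers (8 halves) for C/D with FP16 accumulation. Build with `CUTLASS_ENABLE_TENSOR_CORE_MMA=1` and `-arch=sm_70`.

```
#include <cuda_fp16.h>
#include "cutlass/arch/mma.h"  // assumed path of the file shown above

__global__ void warp_mma_16x16x4(half const *a, half const *b, half const *c, half *d) {
  int lane = threadIdx.x & 31;
  half frag_a[4], frag_b[4], frag_c[8], frag_d[8];

  // Illustrative loads only; the true thread<->matrix mapping is omitted here.
  for (int i = 0; i < 4; ++i) { frag_a[i] = a[lane * 4 + i]; frag_b[i] = b[lane * 4 + i]; }
  for (int i = 0; i < 8; ++i) { frag_c[i] = c[lane * 8 + i]; }

  // One warp-wide 16x16x4 multiply-accumulate with FP16 accumulation.
  cutlass::arch::mma<cutlass::Shape<4, 16, 16>,
                     cutlass::MatrixLayout::kRowMajor, half,
                     cutlass::MatrixLayout::kColumnMajor, half,
                     half,
                     cutlass::arch::ComputeType::kDefault>(frag_a, frag_b, frag_c, frag_d);

  for (int i = 0; i < 8; ++i) { d[lane * 8 + i] = frag_d[i]; }
}
```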
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
[The same copyright-year hunk (2017-2018 → 2017-2019) appears in two further files at this point in the commit.]
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -29,11 +29,12 @@
 #pragma once

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 #define CUTLASS_MAJOR 1
-#define CUTLASS_MINOR 2
-#define CUTLASS_PATCH 1
+#define CUTLASS_MINOR 3
+#define CUTLASS_PATCH 0
 #define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)

 #ifdef __NVCC__
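As a quick check of the packed encoding (illustrative, not part of the commit):

```
#include "cutlass/cutlass.h"
// 1*100 + 3*10 + 0
static_assert(CUTLASS_VERSION == 130, "CUTLASS 1.3.0 packs to 130");
```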
@@ -47,9 +48,31 @@
 // CUTLASS_DEVICE is an error if not compiling device code
 #endif

+// CUDA 10.1 introduces the mma instruction
+#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA)
+#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0
+#endif
+
 // CUTLASS assert
 #define CUTLASS_ASSERT(x) assert(x)

-#include "cutlass/util/performance_tuning.h"
+// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
+#if defined(__CUDA_ARCH__)
+#define CUTLASS_PRAGMA_UNROLL #pragma unroll
+#define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1
+
+#define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
+
+#define CUTLASS_GEMM_LOOP_HEADER \
+  asm volatile (".pragma \"nounroll\";\n");
+#else
+
+#define CUTLASS_PRAGMA_UNROLL
+#define CUTLASS_PRAGMA_NO_UNROLL
+#define CUTLASS_GEMM_LOOP_HEADER
+#define CUTLASS_GEMM_LOOP
+
+#endif

 // A small helper class to dump a type at compile time
 // Usage:: DumpType<Class>::Class
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -98,9 +98,9 @@ struct StorageType<1> {
 template <typename Element_, int kElements_, size_t kAlignment_ = 16>
 struct Fragment : public AlignedStruct<kAlignment_> {
   /// Make sure the alignment makes sense wrt the size of elements.
-  static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
+  static_assert(int(kAlignment_) == 16 || int(kAlignment_) >= sizeof(Element_), "Alignment is too small");
   /// Alignment must be a power of two
-  static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
+  static_assert(is_pow2<int(kAlignment_)>::value, "Alignment must be a power of two");

   /// This class.
   typedef Fragment<Element_, kElements_> This_;
@@ -109,27 +109,31 @@ struct Fragment : public AlignedStruct<kAlignment_> {
   /// The number of elements.
   static int const kElements = kElements_;
   /// Alignment
-  static int const kAlignment = kAlignment_;
+  static int const kAlignment = int(kAlignment_);

   /// Clear a fragment.
   CUTLASS_HOST_DEVICE void clear() {
+    // Avoid element-wise access for sub 32b element type
+    if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
+      uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
+        ptr[i] = uint64_t(0);
+      }
+    } else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
+      uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
+        ptr[i] = uint32_t(0);
+      }
+    } else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
+      uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
+        ptr[i] = uint16_t(0);
+      }
+    } else {
+      CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < kElements; ++i) {
         storage[i] = 0;
       }
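For example, a Fragment<half, 8> (16 bytes of storage at the default 16-byte alignment) is now cleared with two 64-bit stores instead of eight element-wise half-precision writes.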
@@ -146,7 +150,7 @@ struct Fragment : public AlignedStruct<kAlignment_> {
 private:
   /// Storage type to use for Elements
-  typedef typename StorageType<kAlignment_>::Type StorageType;
+  typedef typename StorageType<int(kAlignment_)>::Type StorageType;

   /// Number of elements in the storage
   static int const kStorageCount =
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -68,7 +68,6 @@ struct FragmentMultiplyAdd {
                               FragmentB_ const& b,
                               FragmentCd_ const& c,
                               FragmentCd_& d) {
-
     int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
     for (int j = 0; j < FragmentCd_::kElements; ++j) {
       d[j] = b[j * kReduction + 0];
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -33,6 +33,8 @@

 #include "cutlass/coord.h"
 #include "cutlass/util/platform.h"
+#include "cutlass/gemm/gemm.h"

 namespace cutlass {
 namespace gemm {
@@ -47,7 +49,8 @@ struct DeviceGemm {
 #if !defined(__CUDACC_RTC__)
   /// Launch the kernels in order
   static __host__ cudaError_t launch(Params const& params) {
-    Traits::GemmTraits::KernelClass::launch(params.GemmParams);
+    //Traits::GemmTraits::KernelClass::launch(params.GemmParams);
+    Gemm<typename Traits::GemmTraits>::launch(params.GemmParams);
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess)
       return err;
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -73,7 +73,7 @@ struct SplitkPIGemmTraits {
   /// The pointer to workspace memory
   ScalarAccum *workspace_ptr;
   ///
-  int workspace_size;
+  size_t workspace_size;
   /// The Params for the first kernel
   typename GemmTraits::Params GemmParams;
   /// The Params for the second kernel
@@ -112,7 +112,8 @@ struct SplitkPIGemmTraits {
                  Index ldc_,
                  ScalarD* d_d_,
                  Index ldd_,
-                 ScalarAccum *workspace_ptr_) {
+                 ScalarAccum *workspace_ptr_,
+                 Index partitionK_multiple = 1) {

     workspace_ptr = workspace_ptr_;
@@ -133,7 +134,7 @@ struct SplitkPIGemmTraits {
       TensorRef<typename GemmTraits::ScalarC const, 2>(workspace_ptr, problem_size.m()), /*m = ldc, workspace is not transposed and is packed*/
       TensorRef<typename GemmTraits::ScalarD, 2>(workspace_ptr, problem_size.m()) /*m = ldd, workspace is not transposed and is packed*/
     );
-    GemmParams.initialize(desc, ReductionTraits::ReductionSize);
+    GemmParams.initialize(desc, ReductionTraits::ReductionSize, partitionK_multiple);

     //call batched reduction (second kernel) param
@@ -155,9 +156,12 @@ struct SplitkPIGemmTraits {
   // workspace will be used to store D (output) from the first gemm kernel (not D of the entire gemm)
   // note typedef typename GemmTraits::ScalarD ScalarAccum;
   // workspace of size of M * N * Reduction
-  int required_workspace_memory_in_byte(){
+  size_t required_workspace_memory_in_byte(){
     assert(problem_size_initialized == true);
-    workspace_size = problem_size.n() * problem_size.m() * ReductionTraits::ReductionSize * static_cast<int>(sizeof(ScalarAccum));
+    workspace_size = static_cast<size_t>(problem_size.n()) *
+                     static_cast<size_t>(problem_size.m()) *
+                     static_cast<size_t>(ReductionTraits::ReductionSize) *
+                     sizeof(ScalarAccum);
     return workspace_size;
   }
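A quick sizing example (not from the diff): for m = n = 4096 with ReductionSize = 16 and float accumulators, the workspace is 4096 × 4096 × 16 × 4 B = 1 GiB — already near the 2 GiB limit of a signed 32-bit int, which is why the computation above now widens to size_t before multiplying.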
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -29,6 +29,7 @@

 #include "cutlass/fragment.h"
+#include "cutlass/gemm/thread_multiply_add.h"

 namespace cutlass {
 namespace gemm {
@@ -69,8 +70,10 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, float> {
                                    FragmentB const& b,
                                    Accumulators const& c,
                                    Accumulators& d) {

     for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
       for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {

+        d[j * AccumulatorsPerThread::kW + i] = static_cast<ScalarC>(a[i]) * static_cast<ScalarC>(b[j]) + c[j * AccumulatorsPerThread::kW + i];
       }
     }
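Note on the specialization above: A and B stay in half precision, but the static_casts promote each product to float before it is accumulated, so the accumulators never round to half between updates.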
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -33,20 +33,30 @@

 #include "cutlass/coord.h"
 #include "cutlass/util/platform.h"
+#include <cstdio>

 namespace cutlass {
 namespace gemm {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+/// GEMM kernel with launch bounds specified
 template <typename Gemm_>
 __global__ __launch_bounds__(Gemm_::kThreads)
 void gemm_kernel(typename Gemm_::Params params) {
-  // Declare shared memory.
-  __shared__ typename Gemm_::SharedStorage shared_storage;
+  // Dynamic shared memory base pointer
+  extern __shared__ int GemmSharedStorageBase[];
+
+  // Declare pointer to dynamic shared memory.
+  typename Gemm_::SharedStorage *shared_storage =
+      reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);

   // Construct the GEMM object.
-  Gemm_ gemm(params, shared_storage);
+  Gemm_ gemm(params, *shared_storage);

   // Run GEMM.
   gemm.multiply_add();
 }
@@ -57,11 +67,17 @@ void gemm_kernel(typename Gemm_::Params params) {
 template <typename Gemm_>
 __global__ /* __launch_bounds__(Gemm_::kThreads) */
 void gemm_kernel_nolb(typename Gemm_::Params params) {
-  // Declare shared memory.
-  __shared__ typename Gemm_::SharedStorage shared_storage;
+  // Dynamic shared memory base pointer
+  extern __shared__ int GemmSharedStorageBase[];
+
+  // Declare pointer to dynamic shared memory.
+  typename Gemm_::SharedStorage *shared_storage =
+      reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);

   // Construct the GEMM object.
-  Gemm_ gemm(params, shared_storage);
+  Gemm_ gemm(params, *shared_storage);

   // Run GEMM.
   gemm.multiply_add();
 }
@@ -72,7 +88,31 @@ void gemm_kernel_nolb(typename Gemm_::Params params) {
 template <typename Gemm, bool WithLaunchBounds>
 struct Launch {
   Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
-    gemm_kernel<Gemm><<< grid, block, 0, stream >>>(params);
+
+    int smem_size = int(sizeof(typename Gemm::SharedStorage));
+    if (smem_size >= (48 << 10)) {
+
+      cudaError_t result = cudaFuncSetAttribute(
+        gemm_kernel<Gemm>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize,
+        smem_size
+      );
+
+      if (result != cudaSuccess) {
+        return;
+      }
+
+      result = cudaFuncSetAttribute(
+        gemm_kernel_nolb<Gemm>,
+        cudaFuncAttributePreferredSharedMemoryCarveout,
+        100);
+
+      if (result != cudaSuccess) {
+        return;
+      }
+    }
+
+    gemm_kernel<Gemm><<< grid, block, sizeof(typename Gemm::SharedStorage), stream >>>(params);
   }
 };
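The cudaFuncSetAttribute() calls above are the standard CUDA 9+ opt-in for more than the default 48 KB of dynamic shared memory per block. A self-contained sketch of just that pattern (assumptions: a Volta-class GPU and a made-up kernel name):

```
#include <cuda_runtime.h>
#include <cstdio>

extern __shared__ int smem[];  // dynamic shared memory base pointer

__global__ void big_smem_kernel(int n) {
  if (threadIdx.x < n) smem[threadIdx.x] = threadIdx.x;
}

int main() {
  int smem_size = 64 << 10;  // 64 KB: above the 48 KB default, allowed on sm_70 (up to 96 KB)
  cudaFuncSetAttribute(big_smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  big_smem_kernel<<<1, 256, smem_size>>>(256);
  cudaDeviceSynchronize();
  printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  return 0;
}
```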
@@ -82,50 +122,51 @@ struct Launch {
 template <typename Gemm>
 struct Launch<Gemm, false> {
   Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
-    gemm_kernel_nolb<Gemm><<< grid, block, 0, stream >>>(params);
+    int smem_size = int(sizeof(typename Gemm::SharedStorage));
+    if (smem_size >= (48 << 10)) {
+
+      cudaError_t result = cudaFuncSetAttribute(
+        gemm_kernel_nolb<Gemm>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize,
+        smem_size
+      );
+
+      if (result != cudaSuccess) {
+        return;
+      }
+
+      result = cudaFuncSetAttribute(
+        gemm_kernel_nolb<Gemm>,
+        cudaFuncAttributePreferredSharedMemoryCarveout,
+        100);
+
+      if (result != cudaSuccess) {
+        // throw exception?
+        return;
+      }
+    }
+
+    gemm_kernel_nolb<Gemm><<<
+      grid,
+      block,
+      smem_size,
+      stream >>>(params);
   }
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <typename GemmTraits_>
+template <typename Traits_>
 struct Gemm {
-  /// This class.
-  typedef Gemm<GemmTraits_> This_;
-
-  /// The traits.
-  typedef GemmTraits_ Traits;
-  /// The shared storage.
-  typedef typename Traits::SharedStorage SharedStorage;
-
-  /// The scalar for A.
-  typedef typename Traits::ScalarA ScalarA;
-  /// The scalar for B.
-  typedef typename Traits::ScalarB ScalarB;
-  /// The scalar in the epilogue.
-  typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
-  /// The scalar for C.
-  typedef typename Traits::Epilogue::ScalarC ScalarC;
-  /// The scalar for D.
-  typedef typename Traits::Epilogue::ScalarD ScalarD;
-  /// The index.
-  typedef typename Traits::Index Index;
-
-  /// Define the mainloop iteration size
-  typedef typename Traits::MultiplyAdd MultiplyAdd;
-
-  /// The number of threads.
-  static int const kThreads = Traits::GemmConfig::kThreads;
-
-  // Number of warp-level multiply-accumulate steps executed by each warp.
-  static Index const kWarpGemmSteps =
-      Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
-
-  // Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
-  static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
+  typedef Traits_ Traits;
+
+  /// Use the params object defined in traits
+  typedef typename Traits::Params Params;
+
+  typedef typename Traits::KernelClass KernelClass;

   //
   // Static function members
   //
@@ -137,7 +178,7 @@ struct Gemm {
                                     cudaStream_t stream = cudaStreamDefault) {

     // Launch the kernel.
-    Launch<This_, GemmTraits_::GemmConfig::kLaunchBounds>(
+    Launch<KernelClass, Traits::GemmConfig::kLaunchBounds>(
       params, params.grid, params.block, stream);

     return cudaGetLastError();
@@ -164,189 +205,6 @@ struct Gemm {
   }

 #endif

-  //
-  // Methods
-  //
-
-  /// Ctor.
-  CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
-      : params(params_), shared_storage(shared_storage_) {}
-
-  /// Computes a warp-level GEMM on data held in shared memory
-  template <bool Residue, bool LastIteration>
-  CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
-                                   typename Traits::SharedStream& shared_load_stream,
-                                   typename MultiplyAdd::Accumulators& accumulators,
-                                   Index outer_k) {
-    // If residue portion and not calculating residue in prolog, update residue predicates now.
-    if (Residue && outer_k <= Traits::OutputTile::kD) {
-      global_to_shared_stream.residue(outer_k);
-    }
-
-    // Load data for the next iteration of the main loop (unless it's the last iteration).
-    if (!LastIteration) {
-      global_to_shared_stream.copy();
-    }
-
-    CUTLASS_PRAGMA_UNROLL
-    for (int step = 0; step < kWarpGemmSteps - 1; ++step) {
-      // Trigger the copy from shared memory for the next A/B values.
-      shared_load_stream.copy(step + 1);
-
-      // Make sure the values are available for the current iteration to do the multiply-add.
-      shared_load_stream.commit(step);
-
-      MultiplyAdd multiply_add;
-
-      // Do the math on the fragments of the current iteration.
-      multiply_add.multiply_add(shared_load_stream.fragment_a(step),
-                                shared_load_stream.fragment_b(step),
-                                accumulators,
-                                accumulators);
-    }
-
-    // Make sure the data from shared memory has been entirely consumed.
-    Traits::shared_load_fence(true);
-
-    // Commit the data in shared memory for A/B.
-    if (!LastIteration) {
-      global_to_shared_stream.commit();
-    }
-    // Make sure the data is in shared memory.
-    Traits::shared_store_fence(true);
-
-    if (!LastIteration) {
-      // Move to the next stage for the load (if it makes sense).
-      shared_load_stream.inc_stage();
-      // Trigger the copy from shared memory for the next loop iteration.
-      shared_load_stream.copy(0);
-    }
-    // Make sure the values are available for the current iteration to do the multiply-add.
-    shared_load_stream.commit(kWarpGemmSteps - 1);
-
-    // Do the math on the fragments of the current iteration.
-    MultiplyAdd multiply_add;
-    multiply_add.multiply_add(shared_load_stream.fragment_a(kWarpGemmSteps - 1),
-                              shared_load_stream.fragment_b(kWarpGemmSteps - 1),
-                              accumulators,
-                              accumulators);
-  }
-
-  /// Do the GEMM.
-  CUTLASS_DEVICE void multiply_add() {
-    // Swizzle the IDs of the block (to enable better cache behavior).
-    typename Traits::BlockSwizzle block_swizzle;
-    Coord<3> threadblock_offset =
-        block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::OutputTile>());
-
-    // We may want to use shared memory to clear the registers.
-    typedef typename Traits::ClearAccumulators ClearAccumulators;
-
-    // Get the bounds for each thread, it maybe different than problem_size
-    Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
-                                                           params.partitionK_range);
-
-    // The streams to read A/B from global memory to shared memory.
-    typename Traits::GlobalLoadStream global_to_shared_stream(
-        params.global_to_shared_stream,
-        shared_storage.main_loop.global_to_shared_stream,
-        shared_storage.main_loop.threadblock_tile.reference(),
-        bounds,
-        threadblock_offset);
-
-    // update A and B pointer offset based on batch_id and batch_stride_offset
-    global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());
-
-    // Create the accumulator clear.
-    ClearAccumulators clear;
-
-    // Deal with residue in prolog.
-    // global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
-    global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);
-
-    // Fetch the fragments for A and B from global memory.
-    global_to_shared_stream.copy();
-
-    // Copy the elements to shared memory (after transformation if needed).
-    global_to_shared_stream.commit();
-
-    // Make sure the data is in shared memory.
-    Traits::shared_store_fence(false);
-
-    // Rollback to the beginning of the first tile (if residue exists).
-    // global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
-    global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);
-
-    // The stream of data from shared memory to fragments.
-    typename Traits::SharedStream shared_load_stream(
-        params.shared_stream,
-        shared_storage.main_loop.threadblock_tile.reference());
-
-    // Trigger the copy from shared memory for the 1st stream.
-    shared_load_stream.copy(0);
-
-    // Allocate the accumulators.
-    typename MultiplyAdd::Accumulators accumulators;
-
-    // Clear the accumulators.
-    clear.clear(accumulators);
-
-    // Initial index
-    // Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
-    // problem_size[0] might be bigger than bounds[0]
-    Index outer_k = bounds[0] - Traits::OutputTile::kD;
-    // Check if we are computing residue in prolog or not.
-    if (Traits::GemmConfig::kResidueInProlog) {
-      // Execute all mainloop iterations but the last one.
-
-      CUTLASS_GEMM_LOOP
-      for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
-        consume_tile<false, false>(
-            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
-      }
-
-      // Don't load data for the last "residue" portion since we've already computed the residue.
-      CUTLASS_GEMM_LOOP
-      for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
-        consume_tile<false, true>(
-            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
-      }
-    } else {
-      // When kResidueSeparate = true, execute all mainloop iterations but the last two without any
-      // consideration for K-residue or predicate updates. This improves the steady state of some
-      // kernels.
-      if (Traits::GemmConfig::kResidueSeparate) {
-
-        CUTLASS_GEMM_LOOP
-        for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
-          consume_tile<false, false>(
-              global_to_shared_stream, shared_load_stream, accumulators, outer_k);
-        }
-      }
-
-      // Execute remaining tiles with K-residue predicate updates enabled.
-      CUTLASS_GEMM_LOOP
-      for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
-        consume_tile<true, false>(
-            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
-      }
-    }
-
-    // Epilogue.
-    typedef typename Traits::Epilogue Epilogue;
-    Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
-    epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
-  }
-
-  //
-  // Data members
-  //
-
-  /// The params.
-  Params const& params;
-  /// The shared storage.
-  SharedStorage& shared_storage;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
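The removed consume_tile()/multiply_add() bodies implement a two-stage software pipeline: the copy for step k+1 is triggered before the math for step k is issued, so memory latency hides behind the multiply-adds (hence the removed static_assert requiring kWarpGemmSteps >= 2). A scalar, host-only sketch of that scheduling idea (illustrative only, no CUTLASS types):

```
#include <cstdio>

int main() {
  const int kSteps = 4;
  int frag[2];            // double-buffered "fragments"
  frag[0] = 0;            // prolog: fetch the fragment for step 0
  int acc = 0;
  for (int step = 0; step < kSteps; ++step) {
    if (step + 1 < kSteps)
      frag[(step + 1) % 2] = step + 1;  // trigger the copy for the NEXT step
    acc += frag[step % 2];              // do the math on the CURRENT fragment
  }
  printf("acc = %d\n", acc);            // 0 + 1 + 2 + 3 = 6
  return 0;
}
```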
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@@ -0,0 +1,274 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements a software-pipelined efficient GEMM.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/coord.h"

namespace cutlass {
namespace gemm {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Traits_>
struct GemmMainloop {

  //
  // Type definitions
  //

  /// The traits.
  typedef Traits_ Traits;

  /// The GEMM mainloop
  typedef typename Traits::KernelClass KernelClass;

  /// The shared storage.
  typedef typename Traits::SharedStorage SharedStorage;

  /// The scalar for A.
  typedef typename Traits::ScalarA ScalarA;
  /// The scalar for B.
  typedef typename Traits::ScalarB ScalarB;
  /// The scalar in the epilogue.
  typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
  /// The scalar for C.
  typedef typename Traits::Epilogue::ScalarC ScalarC;
  /// The scalar for D.
  typedef typename Traits::Epilogue::ScalarD ScalarD;
  /// The index.
  typedef typename Traits::Index Index;

  /// Define the mainloop iteration size
  typedef typename Traits::MultiplyAdd MultiplyAdd;

  /// The number of threads.
  static int const kThreads = Traits::GemmConfig::kThreads;

  // Number of warp-level multiply-accumulate steps executed by each warp.
  static Index const kWarpGemmSteps =
      Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;

  /*
  // Make sure we have at least 2 unrolling steps or our pipelining is not going to work.
  static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
  */

  /// Use the params object defined in traits
  typedef typename Traits::Params Params;

  //
  // Data members
  //

  /// The params.
  Params const& params;

  /// SharedStorage object
  SharedStorage& shared_storage;

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE GemmMainloop(Params const& params_, SharedStorage& shared_storage_)
      : params(params_), shared_storage(shared_storage_) {}

  /// Fetches global stream pair
  template <bool Residue>
  CUTLASS_DEVICE void fetch_global(typename Traits::GlobalLoadStream& global_to_shared_stream,
                                   Index outer_k) {
    // If residue portion and not calculating residue in prolog, update residue predicates now.
    if (Residue) {
      global_to_shared_stream.residue(outer_k);
    }
    global_to_shared_stream.copy();
  }

  /// Computes a warp-level GEMM on data held in shared memory
  template <bool Residue, bool LastIteration>
  CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
                                   typename Traits::SharedStream& shared_load_stream,
                                   typename MultiplyAdd::Accumulators& accumulators,
                                   Index outer_k) {

    // Whether to load global stream before loading shared stream
    const bool kGlobalStreamFirst = (kWarpGemmSteps <= 4);

    // Load data for the next iteration of the main loop (unless it's the last iteration).
    if (kGlobalStreamFirst && !LastIteration) {
      fetch_global<Residue>(global_to_shared_stream, outer_k);
    }

    CUTLASS_PRAGMA_UNROLL
    for (int step = 0; step < kWarpGemmSteps; ++step) {

      // Trigger the copy from shared memory for the next A/B values.
      shared_load_stream.copy((step + 1) % kWarpGemmSteps);

      // Load data for the next iteration of the main loop (unless it's the last iteration).
      if (!kGlobalStreamFirst && (step == 0) && !LastIteration) {
        fetch_global<Residue>(global_to_shared_stream, outer_k);
      }

      if (step == kWarpGemmSteps - 2) {
        // Make sure the data from shared memory has been entirely consumed.
        Traits::shared_load_fence(true);

        global_to_shared_stream.commit();

        // Make sure the data is in shared memory.
        Traits::shared_store_fence(true);

        // Move to the next stage for the load (if it makes sense).
        shared_load_stream.inc_stage();
      }

      // Make sure the values are available for the current iteration to do the multiply-add.
      shared_load_stream.commit(step);

      // Do the math on the fragments of the current iteration.
      MultiplyAdd multiply_add;
      multiply_add.multiply_add(shared_load_stream.fragment_a(step),
                                shared_load_stream.fragment_b(step),
                                accumulators,
                                accumulators);
    }
  }

  /// Do the GEMM.
  CUTLASS_DEVICE void multiply_add() {
    // Swizzle the IDs of the block (to enable better cache behavior).
    typename Traits::BlockSwizzle block_swizzle;
    Coord<3> threadblock_offset =
        block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::OutputTile>());

    // We may want to use shared memory to clear the registers.
    typedef typename Traits::ClearAccumulators ClearAccumulators;

    // Get the bounds for each thread; they may differ from problem_size.
    Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
                                                           params.partitionK_range);

    // The streams to read A/B from global memory to shared memory.
    typename Traits::GlobalLoadStream global_to_shared_stream(
        params.global_to_shared_stream,
        shared_storage.main_loop.global_to_shared_stream,
        shared_storage.main_loop.threadblock_tile.reference(),
        bounds,
        threadblock_offset);

    // Update the A and B pointer offsets based on batch_id and the batch stride offset.
    global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());

    // Create the accumulator clear.
    ClearAccumulators clear;

    // Deal with residue in prolog.
    // global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
    global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);

    // Fetch the fragments for A and B from global memory.
    global_to_shared_stream.copy();

    // Copy the elements to shared memory (after transformation if needed).
    global_to_shared_stream.commit();

    // Make sure the data is in shared memory.
    Traits::shared_store_fence(false);

    // Rollback to the beginning of the first tile (if residue exists).
    // global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
    global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);

    // The stream of data from shared memory to fragments.
    typename Traits::SharedStream shared_load_stream(
        params.shared_stream,
        shared_storage.main_loop.threadblock_tile.reference());

    // Trigger the copy from shared memory for the 1st stream.
    shared_load_stream.copy(0);

    // Allocate the accumulators.
    typename MultiplyAdd::Accumulators accumulators;

    // Clear the accumulators.
    clear.clear(accumulators);

    // Initial index
    // Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
    // problem_size[0] might be bigger than bounds[0]
    Index outer_k = bounds[0] - Traits::OutputTile::kD;

    // Check if we are computing residue in prolog or not.
    if (Traits::GemmConfig::kResidueInProlog) {
      // Execute all mainloop iterations but the last one.
      CUTLASS_GEMM_LOOP
      for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
        CUTLASS_GEMM_LOOP_HEADER
        consume_tile<false, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }

      consume_tile<false, true>(
          global_to_shared_stream, shared_load_stream, accumulators, outer_k);

    } else {
      // When kResidueSeparate = true, execute all mainloop iterations but the last two without any
      // consideration for K-residue or predicate updates. This improves the steady state of some
      // kernels.
      if (Traits::GemmConfig::kResidueSeparate) {
        CUTLASS_GEMM_LOOP
        for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
          CUTLASS_GEMM_LOOP_HEADER
          consume_tile<false, false>(
              global_to_shared_stream, shared_load_stream, accumulators, outer_k);
        }
      }

      // Execute remaining tiles with K-residue predicate updates enabled.
      CUTLASS_GEMM_LOOP
      for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
        CUTLASS_GEMM_LOOP_HEADER
        consume_tile<true, false>(
            global_to_shared_stream, shared_load_stream, accumulators, outer_k);
      }
    }

    typedef typename Traits::Epilogue Epilogue;
    Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
    epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
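The interleaving that consume_tile produces is easier to see outside the template machinery. Below is a host-side trace of the schedule — a sketch only, with the tile count, step count, and the kGlobalStreamFirst == false path chosen as assumptions for illustration; none of the printed labels are CUTLASS APIs.

#include <cstdio>

int main() {
  const int kTiles = 4;          // number of K tiles traversed by outer_k (assumption)
  const int kWarpGemmSteps = 4;  // warp-level MAC steps per tile (assumption)

  // Prologue: tile 0 is fetched, committed to shared memory, and fenced
  // before the mainloop starts (see multiply_add above).
  std::printf("prologue: fetch+commit tile 0, store fence\n");

  for (int tile = 0; tile < kTiles; ++tile) {
    bool last = (tile == kTiles - 1);
    for (int step = 0; step < kWarpGemmSteps; ++step) {
      // Prefetch the next A/B fragments from shared memory into registers.
      std::printf("tile %d step %d: smem->reg copy of step %d\n",
                  tile, step, (step + 1) % kWarpGemmSteps);

      // Early in the tile, start the global fetch of the next tile.
      if (step == 0 && !last)
        std::printf("tile %d step %d: global fetch of tile %d\n", tile, step, tile + 1);

      // Two steps before the tile ends, publish the fetched tile to shared memory.
      if (step == kWarpGemmSteps - 2)
        std::printf("tile %d step %d: load fence, smem commit, store fence\n", tile, step);

      // Warp-level multiply-accumulate for this step.
      std::printf("tile %d step %d: warp MAC\n", tile, step);
    }
  }
  return 0;
}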
@@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -91,8 +91,16 @@ struct SharedLoadStream {
    transformer = Transformer();
  }

  /// Clears the fragment
  CUTLASS_DEVICE void clear() {
    fetched[0].clear();
    fetched[1].clear();
    transformed[0].clear();
    transformed[1].clear();
  }

  /// Load the data from shared memory to the fetch fragment.
  CUTLASS_DEVICE void copy() {
    iterator.load_post_increment(fetched[0]);
  }
@@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -140,14 +140,19 @@ struct GlobalLoadStreamPair {

  /// Trigger the copies from shared memory to registers.
  CUTLASS_DEVICE void copy() {
    stream_a.copy();
    stream_b.copy();
  }

  /// Commit the data.
  CUTLASS_DEVICE void commit() {
    stream_a.commit();
    stream_b.commit();
  }

  /// Execute the residue code.

@@ -233,6 +238,13 @@ struct SharedStreamPair {
    stream_b.commit(step);
  }

  /// Clears all fragments
  CUTLASS_DEVICE
  void clear() {
    stream_a.clear();
    stream_b.clear();
  }

  /// The fragment A.
  CUTLASS_DEVICE
  typename StreamA::TransformedFragment const &fragment_a(int step) const {
@@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -42,7 +42,7 @@
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/gemm/gemm_shared_stream.h"
#include "cutlass/gemm/threadblock_swizzle.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm_mainloop.h"

namespace cutlass {
namespace gemm {

@@ -359,7 +359,7 @@ struct GemmTraits {
                 ClearAccumulators_> This_;

  /// The struct that consumes this Traits
  typedef typename cutlass::gemm::Gemm<This_> KernelClass;
  typedef typename cutlass::gemm::GemmMainloop<This_> KernelClass;

  /// The configuration.
  typedef GemmConfig_ GemmConfig;

@@ -544,16 +544,26 @@ struct GemmTraits {

  /// Helper to construct a partitionedK GEMM params
  template <typename GemmDesc_>
  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, Index partitionK_count_) {
  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc,
                                     Index partitionK_count_,
                                     // each partition will be a multiple of partitionK_multiple_
                                     Index partitionK_multiple_ = 1) {
    // partitionK GEMM is a specialized batched strided GEMM with a different K range per batch;
    // the problem_size of each batch is (lastK_size, n, m).
    // The K range for every batch except the last one:
    // assert(partitionK_count_ > 0);
    partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_;
    partitionK_range = partitionK_range - (partitionK_range % partitionK_multiple_);
    // The K range of the last batch:
    // int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range;
    int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1);

    assert((partitionK_range % partitionK_multiple_) == 0);
    assert(partitionK_range > 0);
    assert((lastK_range % partitionK_multiple_) == 0);
    assert(lastK_range > 0);

    int k_size = lastK_range;
    int lda = partitonK_desc.A.stride(0);
    int ldb = partitonK_desc.B.stride(0);

@@ -641,7 +651,8 @@ struct GemmTraits {
      Index ldc,
      ScalarD* d_d,
      Index ldd,
      Index partitionK_count_) {
      Index partitionK_count_,
      Index partitionK_multiple_ = 1) {

    GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
        GemmCoord(k, n, m, 1),

@@ -654,7 +665,7 @@ struct GemmTraits {
    );

    return this->initialize(desc, partitionK_count_);
    return this->initialize(desc, partitionK_count_, partitionK_multiple_);
  }
};
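A worked example makes the rounding above concrete. The numbers are hypothetical, not taken from the commit:

#include <cassert>

int main() {
  // Hypothetical problem: K = 1000 split across 4 partitions, each a multiple of 8.
  int k = 1000, partitionK_count = 4, partitionK_multiple = 8;

  int partitionK_range = k / partitionK_count;                 // 250
  partitionK_range -= partitionK_range % partitionK_multiple;  // 248
  int lastK_range = k - partitionK_range * (partitionK_count - 1);  // 1000 - 744 = 256

  // These mirror the assert()s in initialize() above.
  assert(partitionK_range % partitionK_multiple == 0 && partitionK_range > 0);
  assert(lastK_range % partitionK_multiple == 0 && lastK_range > 0);
  return 0;
}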
@@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -29,7 +29,6 @@
#pragma once

#include "cutlass/fragment.h"

#include "cutlass/gemm/thread_multiply_add.h"

namespace cutlass {

@@ -66,6 +65,8 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, half> {
  /// Make sure there's an even number of elements in both dimensions.
  static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size");
  static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size");
  static_assert(AccumulatorsPerThread::kH >= 2 && AccumulatorsPerThread::kW >= 2,
                "HGEMM expects at least 2x2 accumulator tiles per thread.");

  /// Ctor.
  CUTLASS_DEVICE ThreadMultiplyAdd() {}

@@ -84,7 +85,10 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, half> {
  // The output.
  __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);

  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
      // The offsets in the output fragment.
      int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
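The inner loop above leans on packed half2 math: one HFMA2 instruction performs two half-precision FMAs, which is why the per-thread accumulator tile must be at least 2x2. A minimal standalone CUDA sketch of the same idea using the __hfma2 intrinsic (compile with -arch=sm_53 or newer); it is not the CUTLASS fragment machinery:

#include <cuda_fp16.h>
#include <cstdio>

__global__ void hfma2_demo(float2* out) {
  __half2 a = __floats2half2_rn(2.0f, 3.0f);
  __half2 b = __floats2half2_rn(4.0f, 5.0f);
  __half2 c = __floats2half2_rn(1.0f, 1.0f);
  // One instruction, two FMAs: (2*4+1, 3*5+1) = (9, 16).
  __half2 d = __hfma2(a, b, c);
  *out = __half22float2(d);
}

int main() {
  float2* out;
  cudaMallocManaged(&out, sizeof(float2));
  hfma2_demo<<<1, 1>>>(out);
  cudaDeviceSynchronize();
  std::printf("%f %f\n", out->x, out->y);  // 9.000000 16.000000
  cudaFree(out);
  return 0;
}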
@@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -38,7 +38,7 @@
#include "cutlass/gemm/gemm_traits.h"
#include "cutlass/gemm/hgemm_global_tile.h"
#include "cutlass/gemm/hgemm_multiply_add.h"
#include "cutlass/gemm/hgemm_swizzle.h"
#include "cutlass/layout/thread/transform.h"

namespace cutlass {
namespace gemm {

@@ -107,7 +108,8 @@ struct HgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {

template <typename Iterator_>
struct HgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
  typedef HgemmSwizzle<Iterator_> Transformer;
  typedef typename Iterator_::FragmentShape FragmentShape;
  typedef cutlass::layout::thread::Transform<FragmentShape, 2, half, cutlass::MatrixLayout::RowMajor, half, cutlass::MatrixLayout::ColumnMajor> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -122,7 +123,8 @@ struct HgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {

template <typename Iterator_>
struct HgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
  typedef HgemmSwizzle<Iterator_> Transformer;
  typedef typename Iterator_::FragmentShape FragmentShape;
  typedef cutlass::layout::thread::Transform<FragmentShape, 2, half, cutlass::MatrixLayout::RowMajor, half, cutlass::MatrixLayout::ColumnMajor> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -28,8 +28,9 @@
*/
#pragma once

#include "cutlass/fragment.h"
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))

#include "cutlass/fragment.h"
#include "cutlass/gemm/thread_multiply_add.h"

namespace cutlass {

@@ -44,6 +45,11 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int> {
  typedef Shape<4, 1, 1> InstructionShape;
  /// Shape of the thread-level GEMM (K-by-N-by-M)
  typedef ThreadGemmShape_ ThreadGemmShape;

  /// Thread-level GEMM (N-by-M) must be a multiple of 32.
  static_assert((ThreadGemmShape::kH * ThreadGemmShape::kW) % 32 == 0,
                "Thread-level GEMM (N-by-M) must be multiple of 32");

  /// Aliased for compatibility. Will be removed in CUTLASS v2.0
  typedef ThreadGemmShape AccumulatorsPerThread;
  /// The number of threads per warp.

@@ -72,19 +78,18 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int> {
      Accumulators const& c,
      Accumulators& d) {

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)
    // The inputs.
    int const* a_int = reinterpret_cast<int const*>(&a[0]);
    int const* b_int = reinterpret_cast<int const*>(&b[0]);

    for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
      for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {

        asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
                     : "=r"(d[j * AccumulatorsPerThread::kW + i])
                     : "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i]));
      }
    }
#endif
  }
};

@@ -92,3 +97,5 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int> {

}  // namespace gemm
}  // namespace cutlass

#endif  // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))
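The inline PTX above corresponds to the __dp4a intrinsic: a four-way dot product of packed signed bytes accumulated into a 32-bit integer. A standalone sketch of what a single instruction computes (compile with -arch=sm_61 or newer):

#include <cstdio>

__global__ void dp4a_demo(int* out) {
  int a = 0x04030201;  // packed bytes {1, 2, 3, 4}, little-endian
  int b = 0x08070605;  // packed bytes {5, 6, 7, 8}
  int c = 10;          // accumulator input
  // d = 1*5 + 2*6 + 3*7 + 4*8 + 10 = 80, same as the dp4a.s32.s32 asm above.
  *out = __dp4a(a, b, c);
}

int main() {
  int* out;
  cudaMallocManaged(&out, sizeof(int));
  dp4a_demo<<<1, 1>>>(out);
  cudaDeviceSynchronize();
  std::printf("%d\n", *out);  // 80
  cudaFree(out);
  return 0;
}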
@@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -60,6 +60,7 @@ struct IgemmSwizzle {

  /// Transform a fragment.
  CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {

    // Expose src/dst as int arrays.
    int const* src_int = reinterpret_cast<int const*>(&src[0]);
    int* dst_int = reinterpret_cast<int*>(&dst[0]);
@@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -39,7 +39,7 @@
#include "cutlass/gemm/igemm_epilogue.h"
#include "cutlass/gemm/igemm_global_tile.h"
#include "cutlass/gemm/igemm_multiply_add.h"
#include "cutlass/gemm/igemm_swizzle.h"
#include "cutlass/layout/thread/transform.h"
#include "cutlass/reshape_tile.h"

namespace cutlass {

@@ -90,9 +90,10 @@ struct IgemmConfig : public GemmConfig<
  /// kResidueSeparate
  false,
  /// kResidueInPrologue
  false,
  true,
  /// kLaunchBounds
  false> {};
  false>
{};

////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -380,7 +381,8 @@ struct IgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {

template <typename Iterator_>
struct IgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
  typedef IgemmSwizzle<Iterator_> Transformer;
  typedef typename Iterator_::FragmentShape FragmentShape;
  typedef cutlass::layout::thread::Transform<FragmentShape, 2, int8_t, cutlass::MatrixLayout::RowMajor, int8_t, cutlass::MatrixLayout::ColumnMajor> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -395,7 +397,8 @@ struct IgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {

template <typename Iterator_>
struct IgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
  typedef IgemmSwizzle<Iterator_> Transformer;
  typedef typename Iterator_::FragmentShape FragmentShape;
  typedef cutlass::layout::thread::Transform<FragmentShape, 2, int8_t, cutlass::MatrixLayout::RowMajor, int8_t, cutlass::MatrixLayout::ColumnMajor> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1,6 +1,5 @@

/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@@ -1,5 +1,6 @@

/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -0,0 +1,284 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 **************************************************************************************************/
/*! \file
    \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
           with the computed matrix product.
*/

#pragma once

// clang-format off

#include "cutlass/coord.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename EpilogueTraits_>
struct MMAEpilogue {
  /// The traits class.
  typedef EpilogueTraits_ Traits;

  /// The params.
  typedef typename Traits::Params Params;

  /// The shared storage.
  typedef typename Traits::SharedStorage SharedStorage;

  /// Defines a tiling of the EpilogueTile over the entire threadblock GEMM tile
  typedef typename Traits::Iterations Iterations;

  /// The output tile.
  typedef typename Traits::OutputTile OutputTile;

  /// Accumulators to store in the epilogue
  typedef typename Traits::Accumulators Accumulators;

  /// A functor to copy a slice of accumulators for a given epilogue iteration
  typedef typename Traits::SelectAccumulators SelectAccumulators;

  /// The iterator to load the source matrix from global memory.
  typedef typename Traits::GlobalLoadStreamC GlobalLoadStreamC;

  /// The iterator to store the final GEMM computation to global memory.
  typedef typename Traits::GlobalStoreStreamD GlobalStoreStreamD;

  /// The stream to store the matrix product to shared memory
  typedef typename Traits::SharedStoreStreamD SharedStoreStreamD;

  /// The stream to load the matrix product from shared memory
  typedef typename Traits::SharedLoadStreamD SharedLoadStreamD;

  /// The functor in charge of the math.
  typedef typename Traits::Functor Functor;

  /// The scalar type used by the epilogue functor.
  typedef typename Functor::Scalar Scalar;

  /// The scalar type of the source accumulator matrix.
  typedef typename Traits::ScalarC ScalarC;

  /// The scalar type of the destination accumulator matrix.
  typedef typename Traits::ScalarD ScalarD;

  /// The index type.
  typedef typename Traits::Index Index;

  /// Functor computing the offset from the threadblock origin per iteration of
  /// the epilogue.
  typedef typename Traits::GlobalOffset GlobalOffset;

  /// Layout mapping function for global memory.
  typedef typename Traits::GlobalDataLayout GlobalDataLayout;

  //
  // Data members
  //

  /// The params.
  Params const& params;

  /// The shared storage.
  SharedStorage& shared_storage;

  /// The dimensions of the GEMM.
  gemm::GemmCoord problem_size;

  /// Epilogue functor
  Functor functor;

  /// Functor to select a set of accumulators
  SelectAccumulators select_accumulators;

  /// Functor to compute the global offset relative to the threadblock for each iteration
  /// of the epilogue.
  GlobalOffset global_offset;

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE MMAEpilogue(
      Params const& params_,
      SharedStorage& shared_storage_,
      Coord<3> const& _problem_size,
      SelectAccumulators _select_accumulators = SelectAccumulators(),
      GlobalOffset _global_offset = GlobalOffset()
  ):
      params(params_),
      shared_storage(shared_storage_),
      problem_size(_problem_size),
      functor(params_.functor),
      select_accumulators(_select_accumulators),
      global_offset(_global_offset) {}

  /// Execute the epilogue.
  CUTLASS_DEVICE void epilogue(
      Accumulators& accumulators,
      Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
      int batch_id = 0) {

    if (functor.source_required()) {
      epilogue_with_or_without_beta<true>(accumulators, threadblock_offset, batch_id);
    }
    else {
      epilogue_with_or_without_beta<false>(accumulators, threadblock_offset, batch_id);
    }
  }

  /// Execute the epilogue with or without a source operand.
  template <bool kSourceRequired>
  CUTLASS_DEVICE void epilogue_with_or_without_beta(
      Accumulators& accumulators,
      Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
      int batch_id = 0) {

    /// Global memory mapping function
    GlobalDataLayout gmem_map_func;

    // Construct shared memory streams
    SharedStoreStreamD shared_store_stream(
        params.shared_store_stream_d,
        shared_storage.reference());

    SharedLoadStreamD shared_load_stream(
        params.shared_load_stream_d,
        shared_storage.reference());

    // Map the GEMM problem dimensions into the coordinate system of the output memory
    Coord<2> gmem_bounds = gmem_map_func(make_Coord(
        problem_size.m(),    // GEMM M - rows
        problem_size.n()));  // GEMM N - columns

    Coord<3> gmem_tile_bounds = make_Coord(
        problem_size.k(),  // GEMM K
        gmem_bounds[0],    // strided
        gmem_bounds[1]);   // contiguous

    // Iterate over the entire Threadblock tile
    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < Iterations::kH; ++h) {
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < Iterations::kW; ++w) {

        // Offset in GEMM coordinates
        gemm::GemmCoord offset_in_gemm = threadblock_offset + global_offset(make_Coord(h, w));

        Coord<2> offset_in_memory = gmem_map_func(
            make_Coord(
                offset_in_gemm.m(),    // GEMM M - rows
                offset_in_gemm.n()));  // GEMM N - columns

        // Offset of the tile in global memory
        Coord<3> global_tile_offset = make_Coord(
            offset_in_gemm.k(),    // GEMM K
            offset_in_memory[0],   // strided
            offset_in_memory[1]);  // contiguous

        GlobalLoadStreamC global_load_stream(
            params.load_stream_c,
            gmem_tile_bounds,
            global_tile_offset);

        GlobalStoreStreamD global_store_stream(
            params.store_stream_d,
            gmem_tile_bounds,
            global_tile_offset);

        // Update the C pointer offset based on batch_id and the batch stride.
        global_load_stream.iterator.add_pointer_offset(batch_id * params.batch_stride_C);

        // Update the D pointer offset based on batch_id and the batch stride.
        global_store_stream.iterator.add_pointer_offset(batch_id * params.batch_stride_D);

        // Load the C matrix into fragment.
        if (kSourceRequired) {
          global_load_stream.copy();
        }

        // Make sure we can write to shared memory.
        shared_load_fence();

        // Store accumulator tile to shared memory
        shared_store_stream.copy(
            select_accumulators(accumulators, make_Coord(h, w)));

        shared_store_stream.commit();

        // Make sure the data is in shared memory.
        shared_store_fence();

        // Load the accumulators back to registers from shared memory.
        shared_load_stream.copy();
        shared_load_stream.commit();

        // Commit the C matrix fragment
        if (kSourceRequired) {
          global_load_stream.commit();
        }

        // Apply epilogue functor
        if (kSourceRequired) {
          functor.evaluate(shared_load_stream.fragment(),
                           global_load_stream.fragment(),
                           global_store_stream.fragment());
        }
        else {
          functor.evaluate(
              shared_load_stream.fragment(),
              global_store_stream.fragment());
        }

        global_store_stream.copy();
        global_store_stream.commit();
      }
    }
  }

  /// The memory fence for shared loads.
  CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); }

  /// The memory fence for shared stores.
  CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass

// clang-format on
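The Functor that epilogue() dispatches on typically implements a linear combination, D = alpha * accumulator + beta * C, with source_required() reporting whether beta forces a read of C. A minimal scalar model of the two evaluate() paths — an illustration, not the actual CUTLASS functor interface:

#include <cstdio>

struct LinearCombination {
  float alpha, beta;

  // When beta == 0, the epilogue never loads C from global memory.
  bool source_required() const { return beta != 0.0f; }

  // Source fragment available: D = alpha * accum + beta * C.
  float evaluate(float accum, float c) const { return alpha * accum + beta * c; }

  // No source: D = alpha * accum.
  float evaluate(float accum) const { return alpha * accum; }
};

int main() {
  LinearCombination f{2.0f, 0.5f};
  float accum = 3.0f, c = 4.0f;
  float d = f.source_required() ? f.evaluate(accum, c) : f.evaluate(accum);
  std::printf("%f\n", d);  // 2*3 + 0.5*4 = 8.000000
  return 0;
}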
@@ -0,0 +1,360 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 **************************************************************************************************/
/*! \file
    \brief Implements efficient loading of the thread block-level tile from global memory and
           storing to shared memory.
*/

#pragma once

// clang-format off

#include "cutlass/convert.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tile_allocation.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

///! Stream adapter for loading threadblock-scoped GEMM tiles and storing to shared memory
template <
    /// Identifies multiplicand
    GemmOperand::Kind Operand,
    /// Layout of source matrix in global memory
    MatrixLayout::Kind Layout,
    /// Iterator for loading threadblock-scoped tiles
    typename LoadIterator_,
    /// Transformation functor for transforming fragments
    typename Transformer_,
    /// Iterator for storing threadblock-scoped tiles to shared memory
    typename StoreIterator_,
    /// Number of stores before iterator wraps - zero indicates no wrapping
    int StageCount>
struct MMAGlobalLoadStream {
  //
  // Type definitions
  //

  /// Identifies the operand
  static GemmOperand::Kind const kOperand = Operand;
  /// The layout.
  static MatrixLayout::Kind const kLayout = Layout;
  /// The load iterator.
  typedef LoadIterator_ LoadIterator;
  /// The transformer.
  typedef Transformer_ Transformer;
  /// The store iterator to write to shared memory.
  typedef StoreIterator_ StoreIterator;
  /// Number of stages
  static int const kStageCount = StageCount;

  /// Predicate vector
  typedef typename LoadIterator::PredicateVector PredicateVector;
  /// The fragment that is copied from global memory.
  typedef typename LoadIterator::Fragment FetchedFragment;
  /// The fragment that is obtained after the transformation by the transformer.
  typedef typename Transformer::OutputFragment TransformedFragment;
  /// Make sure the fragments match.
  static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
                "");
  /// The output fragment.
  typedef TransformedFragment Fragment;
  /// Make sure the transformed fragment is the same as the store fragment.
  static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
                "");

  /// The scalar type of the iterator.
  typedef typename LoadIterator::Scalar Scalar;
  /// The pointer.
  typedef typename LoadIterator::Pointer Pointer;
  /// The index.
  typedef typename LoadIterator::Index Index;
  /// The long index.
  typedef typename LoadIterator::LongIndex LongIndex;
  /// The tile.
  typedef typename LoadIterator::Tile Tile;

  /// The params.
  struct Params {

    /// Helper
    static int const kElementsPerLdg = LoadIterator::Tile::kC;

    //
    // Data members
    //

    /// The load iterator.
    typename LoadIterator::Params load_iterator;

    /// Stride within a batch of matrix operands
    LongIndex batch_stride;

    // Offset to residue.
    Index offset_to_residue;

    // Offset to residue for the last partition
    Index offset_to_residue_last_partition;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params(): batch_stride(0), offset_to_residue(0), offset_to_residue_last_partition(0) {}

    /// Constructor
    CUTLASS_HOST_DEVICE
    Params(
        TensorRef<half const, 2> const &ref,
        Index _offset_to_residue
    ):
        batch_stride(0),
        offset_to_residue(_offset_to_residue),
        offset_to_residue_last_partition(0),
        load_iterator(
            TensorRef<half const, 4>(
                ref.data(),
                make_Coord(ref.stride(0) * kElementsPerLdg, ref.stride(0), kElementsPerLdg, 1)
            )
        ) {}

    /// Initializer
    CUTLASS_HOST_DEVICE
    int initialize(
        TensorRef<half const, 2> const &ref,
        LongIndex batch_stride_,
        Index offset_to_residue_,
        Index offset_to_residue_last_partition_) {

      batch_stride = batch_stride_;
      offset_to_residue = offset_to_residue_;
      offset_to_residue_last_partition = offset_to_residue_last_partition_;

      return load_iterator.initialize(
          TensorRef<half const, 4>(
              ref.data(),
              make_Coord(static_cast<int>(batch_stride), ref.stride(0), kElementsPerLdg, 1)
          )
      );
    }

    CUTLASS_HOST_DEVICE
    int initialize(
        TensorRef<half const, 2> const &ref,
        Index offset_to_residue_) {

      offset_to_residue = offset_to_residue_;
      return load_iterator.initialize(
          TensorRef<half const, 4>(
              ref.data(),
              make_Coord(ref.stride(0) * kElementsPerLdg, ref.stride(0), kElementsPerLdg, 1)
          )
      );
    }

    CUTLASS_HOST_DEVICE int initialize(Index offset_to_residue_) {
      offset_to_residue = offset_to_residue_;
      return 0;
    }

    CUTLASS_DEVICE Index get_offset_to_residue() {
      if (blockIdx.z == gridDim.z - 1) {  // last partition
        return offset_to_residue_last_partition;
      }
      else {
        return offset_to_residue;
      }
    }
  };

  /// Empty shared storage
  struct SharedStorage {};

  /// Shared memory allocation for the tile
  typedef TileAllocation<
      typename StoreIterator::Scalar,
      typename ShapeMul<
          typename StoreIterator::OperandShape,
          Shape<kStageCount, 1, 1, 1>
      >::Shape
  > ThreadblockTileStorage;

  /// ZipTensorRef to threadblock tiles
  typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;

  //
  // Data members
  //

  ///! The parameters
  Params params;

  ///! Dimensions of global memory tile
  Coord<3> threadblock_offset;

  ///! Dimensions of multiplicand bounds
  Coord<3> multiplicand_bounds;

  ///! Iterator to load threadblock tiles from global memory
  LoadIterator load_iterator;

  ///! Predicate vector
  PredicateVector predicates;

  ///! The fragment to fetch from global memory.
  FetchedFragment fetched_fragment;

  ///! Functor to transform fragments after they have been loaded
  Transformer transformer;

  ///! The fragment holding the converted data after it has been fetched.
  TransformedFragment transformed_fragment;

  ///! Iterator to store threadblock tiles to shared memory
  StoreIterator store_iterator;

  ///! Counter
  int stage_index;

  //
  // Static member functions
  //

  /// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory
  CUTLASS_HOST_DEVICE
  static Coord<3> project_coordinate(Coord<3> const &coord, Index d_offset = 0) {
    bool const kKstrided =
        gemm::GemmMultiplicandTraits<typename LoadIterator::Tile, kOperand, kLayout>::kKstrided;

    Coord<3> tile_coord = gemm::ProjectOperand<kOperand, kKstrided>::project(coord);

    return make_Coord(
        tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
  }

  //
  // Methods
  //

  /// Constructor
  CUTLASS_DEVICE MMAGlobalLoadStream(Params const &_params,
                                     SharedStorage &shared_storage,
                                     ThreadblockTileRef const &threadblock_tile_ref,
                                     Coord<3> const bounds,
                                     Coord<3> const &block)
      : params(_params),
        threadblock_offset(project_coordinate(block)),
        multiplicand_bounds(project_coordinate(bounds, 1)),
        load_iterator(params.load_iterator, threadblock_offset),
        transformer(),
        store_iterator(threadblock_tile_ref.data()),
        stage_index(0) {
    load_iterator.initialize_predicates(
        predicates.begin(), multiplicand_bounds, threadblock_offset);
  }

  /// Loads the data from global memory
  CUTLASS_DEVICE void copy() {
    load_iterator.load_post_increment(fetched_fragment, predicates.begin());
  }

  /// Transform and commit the data to shared memory
  CUTLASS_DEVICE void commit() {
    transformer.transform(fetched_fragment, transformed_fragment);
    store_iterator.store_post_increment(transformed_fragment);

    ++stage_index;
    if (kStageCount && stage_index == kStageCount) {
      store_iterator -= kStageCount;
      stage_index = 0;
    }
  }

  /// Computes a predicate mask for loads during the final threadblock tile load iteration
  CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
    // That's the residue!
    Coord<3> _block_offset = threadblock_offset;
    if (kOperand == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor) {
      // K-strided
      _block_offset =
          make_Coord(threadblock_offset[0], multiplicand_bounds[1] - k, threadblock_offset[2]);
    } else {
      // K-contiguous
      _block_offset = make_Coord(threadblock_offset[0],
                                 threadblock_offset[1],
                                 multiplicand_bounds[2] - k / LoadIterator::Tile::kC);
    }

    load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, _block_offset);
    fetched_fragment.clear();
  }

  /// Move to the residue portion.
  CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
    Index kResidue = k % kTileK;
    if (kResidue) {
      residue(kResidue);
      Index this_offset_residue = params.get_offset_to_residue();
      load_iterator.add_pointer_offset(this_offset_residue * load_iterator.stride_advance());
    }
  }

  /// Rollback to the beginning of the first tile
  CUTLASS_DEVICE void rollback(void) {
    load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, threadblock_offset);

    int const kBlock = kOperand == GemmOperand::kA
                           ? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
                           : (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);
    Index this_offset_residue = params.get_offset_to_residue();
    load_iterator.add_pointer_offset(-(this_offset_residue + kBlock) *
                                     load_iterator.stride_advance());
  }

  /// Adds a Coord<3> to the underlying global load iterator
  CUTLASS_DEVICE MMAGlobalLoadStream &operator+=(Coord<3> const &offset) {
    load_iterator += offset;
    return *this;
  }

  /// Adds an offset based on batch stride
  CUTLASS_DEVICE MMAGlobalLoadStream &add_batch_offset(int batch_id) {
    load_iterator.add_pointer_offset(batch_id * params.batch_stride);
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass

// clang-format on
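commit() above treats the shared-memory tile allocation as a circular buffer: after kStageCount stores the store iterator rewinds to stage 0. A small host-side sketch of the wraparound, with the stage count assumed to be 2 (double buffering):

#include <cstdio>

int main() {
  const int kStageCount = 2;  // assumed double buffering
  int stage_index = 0;
  for (int commit = 0; commit < 6; ++commit) {
    std::printf("commit %d writes shared-memory stage %d\n", commit, stage_index);
    ++stage_index;
    if (kStageCount && stage_index == kStageCount) {
      stage_index = 0;  // the real code rewinds with: store_iterator -= kStageCount
    }
  }
  return 0;
}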
@@ -0,0 +1,201 @@

/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/

#pragma once

#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
#include "cutlass/util/platform.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Iterators used to load multiplicands from global memory specialized for Volta884 access patterns
//
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Iterator for loading data for congruous access patterns
template <GemmOperand::Kind Operand, typename Tile_, int WarpCount, int WarpDelta>
struct MMAThreadblockCongruousLoad {
  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = Operand;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout =
      (Operand == GemmOperand::kA ? MatrixLayout::kColumnMajor : MatrixLayout::kRowMajor);

  /// Shape of thread-block multiplicand
  typedef Tile_ Tile;

  /// Number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp accumulator tiles along the outer dimension
  static int const kWarpDelta = WarpDelta;

  /// This implementation is specialized for 128b loads
  static int const kAccessSize = 8;

  /// Projects the threadblock tile
  typedef typename gemm::GemmMultiplicandTraits<Tile_, Operand, kLayout>::Shape OperandShape;

  /// Reshapes the threadblock tile by access size
  typedef typename ReshapeTile<OperandShape, kAccessSize>::Tile VectorizedShape;

  /// Shape of tile covered by each warp per store operation
  typedef Shape<1, 4, 8> WarpStoreCoverage;

  /// Shape of tile loaded by each warp per load operation
  typedef Shape<1, 4, 8> WarpLoadShape;

  //
  // Load iterator
  //

  /// Delta between accesses
  typedef Shape<1, WarpLoadShape::kH * kWarpCount, WarpLoadShape::kW> Delta;

  typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;

  /// Rakes warps along contiguous dimensions and strip-mines strided
  /// dimension.
  typedef Shape<1,
                VectorizedShape::kH / WarpStoreCoverage::kH / WarpCount,
                VectorizedShape::kW / WarpStoreCoverage::kW,
                1>
      Iterations;

  /// Functor computing starting offset for each thread
  struct ThreadOffset {
    __device__ Coord<4> operator()() const {
      int warp_id = (threadIdx.x >> 5);
      int lane_id = (threadIdx.x & 0x1f);

      int lane_k = lane_id / WarpLoadShape::kW;
      int lane_outer = lane_id % WarpLoadShape::kW;

      Coord<4> offset = make_Coord(0, warp_id * WarpLoadShape::kH + lane_k, lane_outer, 0);

      return offset;
    }
  };

  /// Source tile traits
  typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> LoadTileTraits;

  /// Load iterator
  typedef TileLoadIterator<LoadTileTraits, half, IteratorAdvance::kH> Iterator;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Iterator for loading data for crosswise access patterns
template <GemmOperand::Kind Operand, typename Tile_, int WarpCount, int WarpDelta>
struct MMAThreadblockCrosswiseLoad {
  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = Operand;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout =
      (Operand == GemmOperand::kA ? MatrixLayout::kRowMajor : MatrixLayout::kColumnMajor);

  /// Shape of thread-block multiplicand
  typedef Tile_ Tile;

  /// Number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp accumulator tiles along the outer dimension
  static int const kWarpDelta = WarpDelta;

  /// This implementation is specialized for 128b loads
  static int const kAccessSize = 8;

  /// Projects the threadblock tile
  typedef typename gemm::GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;

  /// Reshapes the threadblock tile by access size
  typedef typename ReshapeTile<OperandShape, kAccessSize>::Tile VectorizedShape;

  /// Shape of tile covered by each warp per store operation
  typedef Shape<1, 8, 4> WarpStoreCoverage;

  /// Shape of tile loaded by each warp per load operation
  typedef Shape<1, 8, 4> WarpLoadShape;

  //
  // Load iterator
  //

  /// Delta between accesses
  typedef Shape<1, WarpLoadShape::kH, WarpLoadShape::kW> Delta;

  typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;

  /// Rakes warps along contiguous dimensions and strip-mines strided
  /// dimension.
  typedef Shape<1,
                VectorizedShape::kH / WarpStoreCoverage::kH / WarpCount,
                VectorizedShape::kW / WarpStoreCoverage::kW,
                1>
      Iterations;

  /// Functor computing starting offset for each thread
  struct ThreadOffset {
    __device__ Coord<4> operator()() const {
      int warp_id = (threadIdx.x >> 5);
      int lane_id = (threadIdx.x & 0x1f);

      int lane_k = lane_id % WarpLoadShape::kW;
      int lane_outer = lane_id / WarpLoadShape::kW;

      Coord<4> offset =
          make_Coord(0, warp_id * Iterations::kH * WarpLoadShape::kH + lane_outer, lane_k, 0);

      return offset;
    }
  };

  /// Source tile traits
  typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> LoadTileTraits;

  /// Load iterator
  typedef TileLoadIterator<LoadTileTraits, half, IteratorAdvance::kW> Iterator;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
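The two ThreadOffset functors above differ only in how a warp's lane index is decomposed: the congruous pattern strides lanes along K with lane_id / WarpLoadShape::kW, while the crosswise pattern puts K in the fastest-varying position via lane_id % WarpLoadShape::kW. A standalone host-side sketch of that decomposition (illustrative only; nothing here is a CUTLASS API):

// Standalone sketch: reproduces the lane decomposition used by the two
// ThreadOffset functors above. Names are illustrative, not CUTLASS APIs.
#include <cstdio>

int main() {
  int const kWarpLoadW_congruous = 8;  // WarpLoadShape = Shape<1, 4, 8>
  int const kWarpLoadW_crosswise = 4;  // WarpLoadShape = Shape<1, 8, 4>

  for (int tid = 0; tid < 64; tid += 17) {
    int warp_id = tid >> 5;        // one warp = 32 threads
    int lane_id = tid & 0x1f;

    // Congruous: K indexed by lane / kW, contiguous dimension by lane % kW
    int congruous_k = lane_id / kWarpLoadW_congruous;
    int congruous_outer = lane_id % kWarpLoadW_congruous;

    // Crosswise: K indexed by lane % kW, outer dimension by lane / kW
    int crosswise_k = lane_id % kWarpLoadW_crosswise;
    int crosswise_outer = lane_id / kWarpLoadW_crosswise;

    std::printf("tid %2d -> warp %d, congruous (k=%d, c=%d), crosswise (k=%d, outer=%d)\n",
                tid, warp_id, congruous_k, congruous_outer, crosswise_k, crosswise_outer);
  }
  return 0;
}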
@@ -0,0 +1,155 @@

/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements efficient loading of the thread block-level tile from shared memory into
      fragments for warp-level matrix multiply-accumulate.
*/

#pragma once

#include "cutlass/convert.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Stream from shared memory to fragments for warp-level matrix multiply-accumulate
template <
    /// The load iterator.
    typename Iterator_,
    /// The transformer to be applied after the data has been copied from shared memory.
    typename Transformer_ = Copy<typename Iterator_::Fragment>,
    /// Number of increments before iterator wraps - zero indicates no wrapping
    int StageCount = 1>
struct MMASharedLoadStream {
  /// The load iterator.
  typedef Iterator_ Iterator;
  /// The transformer.
  typedef Transformer_ Transformer;

  /// Number of increments before iterator wraps - zero indicates no wrapping
  static int const kStageCount = StageCount;

  /// The fragment that is copied from shared memory.
  typedef typename Iterator::Fragment FetchedFragment;
  /// The fragment that is obtained after the transformation by the transformer.
  typedef typename Transformer::OutputFragment TransformedFragment;
  /// Make sure the fragments match.
  static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
                "");
  /// The output fragment.
  typedef TransformedFragment Fragment;

  /// Element type
  typedef typename Iterator::Scalar Scalar;

  /// Reference type to a tensor
  typedef TensorRef<half, 4> TensorRef;

  /// Parameters passed from host
  struct Params {};

  //
  // Data members
  //

  /// Iterator for loading fragments for warp-level matrix multiply-accumulate
  Iterator iterator;

  /// Fetched fragments (double buffered)
  FetchedFragment fetched[2];

  /// The transformer.
  Transformer transformer;

  /// Transformed fragments (double buffered)
  TransformedFragment transformed[2];

  /// Counts the number of stages
  int stage_index;

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE MMASharedLoadStream() : stage_index(0) {}

  /// Ctor.
  CUTLASS_DEVICE MMASharedLoadStream(
      Params const &_params,
      TensorRef const &ref,
      Coord<4> warp_offset = make_Coord(0, 0, 0, 0))
      : iterator(ref.data(), warp_offset), stage_index(0) {}

  /// Load the data from shared memory to the fetch fragment.
  CUTLASS_DEVICE void copy(int step) {
    iterator.load(
        fetched[step % 2],
        make_Coord(step + stage_index * Iterator::VectorizedShape::kD, 0, 0, 0));
  }

  /// Commit the data.
  CUTLASS_DEVICE void commit(int step) {
    transformer.transform(fetched[step % 2], transformed[step % 2]);
  }

  /// Clears both fragment buffers.
  CUTLASS_DEVICE void clear() {
    fetched[0].clear();
    fetched[1].clear();
    transformed[0].clear();
    transformed[1].clear();
  }

  /// Gets the transformed fragment
  CUTLASS_DEVICE
  TransformedFragment &fragment(int step) { return transformed[step % 2]; }

  /// Gets the transformed fragment
  CUTLASS_DEVICE
  TransformedFragment const &fragment(int step) const { return transformed[step % 2]; }

  /// Increment the stage.
  CUTLASS_DEVICE void inc_stage() {
    ++stage_index;
    if (kStageCount && stage_index == kStageCount) {
      stage_index = 0;
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
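MMASharedLoadStream double-buffers its fragments: copy(step) fills fetched[step % 2], so the fragment committed for one mainloop step can be consumed while the next step's data is loaded into the other buffer. A standalone sketch of the index ping-pong (illustrative only, no CUTLASS dependencies):

// Standalone sketch of the two-buffer ping-pong used by MMASharedLoadStream:
// step `s` is fetched into buffer (s % 2) while buffer ((s + 1) % 2), filled
// one step earlier, is consumed. Names are illustrative, not CUTLASS APIs.
#include <cstdio>

int main() {
  int fetched[2] = {0, 0};

  int const kSteps = 6;
  for (int step = 0; step < kSteps; ++step) {
    fetched[step % 2] = step;              // copy(step): fill one buffer
    if (step > 0) {
      int ready = fetched[(step + 1) % 2]; // consume the other buffer
      std::printf("step %d consumes data fetched at step %d\n", step, ready);
    }
  }
  return 0;
}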
@@ -1,6 +1,5 @@

/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@@ -1,5 +1,5 @@

/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@@ -1,5 +1,5 @@

/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@@ -73,16 +73,27 @@ struct ThreadMultiplyAdd {
      FragmentB const& b,
      Accumulators const& c,
      Accumulators& d) {

    if (kLayout_ == MatrixLayout::kColumnMajor) {

      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {

          d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i];
        }
      }
    }
    else {

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {

        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {

          d[i * AccumulatorsPerThread::kH + j] = a[i] * b[j] + c[i * AccumulatorsPerThread::kH + j];
        }
      }

@@ -1,5 +1,5 @@

/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@@ -0,0 +1,348 @@

/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
      with the computed matrix product.
*/

#pragma once

// clang-format off

#include "cutlass/zip_fragment.h"
#include "cutlass/zip_tile_iterator.h"
#include "cutlass/util/complex.h"
#include "cutlass/gemm/volta884_gemm_epilogue_traits.h"
#include "cutlass/gemm/split_complex_linear_scaling.h"
#include "cutlass/util/pair.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Enables treating the accumulators selection as one object
template <typename First_, typename Second_>
struct ZipSelectAccumulators {

  /// Underlying selection functions
  typedef First_ First;
  typedef Second_ Second;

  /// Accumulators
  typedef ZipFragment<
    typename First::Accumulators,
    typename Second::Accumulators> Accumulators;

  /// Fragment
  typedef ZipFragment<
    typename First::Fragment,
    typename Second::Fragment> Fragment;

  //
  // Data members
  //

  /// Selects the accumulators for the first part
  First first;

  /// Selects the accumulators for the second part
  Second second;

  //
  // Methods
  //

  /// Default ctor
  CUTLASS_DEVICE
  ZipSelectAccumulators() { }

  /// Basic constructor
  CUTLASS_DEVICE
  ZipSelectAccumulators(First const &_first, Second const &_second): first(_first), second(_second) { }

  /// Selects accumulators for a given iteration of the epilogue
  CUTLASS_DEVICE
  Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
    return make_ZipFragment(first(accum.first, idx), second(accum.second, idx));
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines epilogue traits for complex-valued mma.sync GEMM
template <
  typename GemmConfig_,
  typename EpilogueFunctor_ = SplitComplexLinearScaling<typename GemmConfig_::MultiplyAdd::ScalarC>,
  typename Index_ = int>
struct Volta884ComplexGemmEpilogueTraits {

  /// GEMM configuration
  typedef GemmConfig_ GemmConfig;

  /// Epilogue functor
  typedef EpilogueFunctor_ Functor;

  /// Global memory mapping function
  typedef MatrixLayout::ColumnMajor GlobalDataLayout;

  /// Index type
  typedef Index_ Index;

  /// Long index used for offsets
  typedef long long LongIndex;

  /// Defines epilogue traits for real-valued Volta884 GEMM epilogue
  typedef typename Volta884GemmEpilogueTraitsHelper<
    GemmConfig,
    Functor,
    typename GemmConfig::MultiplyAdd::RealMultiplyAdd,
    Index>::EpilogueTraits RealEpilogueTraits;

  /// The output tile.
  typedef typename RealEpilogueTraits::OutputTile OutputTile;

  /// The warp-level GEMM tile
  typedef typename RealEpilogueTraits::WarpGemmTile WarpGemmTile;

  /// Tiling of warp accumulator elements
  typedef typename RealEpilogueTraits::WarpGemmTile WarpDelta;

  /// Multiply-add operation
  typedef typename GemmConfig::MultiplyAdd MultiplyAdd;

  /// The accumulators fragment type.
  typedef typename MultiplyAdd::Accumulators Accumulators;

  /// Selects a subset of accumulators for a given epilogue iteration
  typedef ZipSelectAccumulators<
    typename RealEpilogueTraits::SelectAccumulators,
    typename RealEpilogueTraits::SelectAccumulators> SelectAccumulators;

  /// The iterator to load source matrix from global memory.
  typedef cutlass::PredicatedTileLoadStream<
    ZipTileIterator<
      typename RealEpilogueTraits::GlobalLoadStreamC::Iterator,
      typename RealEpilogueTraits::GlobalLoadStreamC::Iterator
    >,
    typename RealEpilogueTraits::GlobalLoadStreamC::PredicateFunctor,
    ZipConvert<
      typename RealEpilogueTraits::GlobalLoadStreamC::Transformer,
      typename RealEpilogueTraits::GlobalLoadStreamC::Transformer
    >
  > GlobalLoadStreamC;

  /// The iterator to store the final GEMM computation to global memory.
  typedef cutlass::PredicatedTileStoreStream<
    ZipTileIterator<
      typename RealEpilogueTraits::GlobalStoreStreamD::Iterator,
      typename RealEpilogueTraits::GlobalStoreStreamD::Iterator
    >,
    typename RealEpilogueTraits::GlobalStoreStreamD::PredicateFunctor,
    ZipConvert<
      typename RealEpilogueTraits::GlobalStoreStreamD::Transformer,
      typename RealEpilogueTraits::GlobalStoreStreamD::Transformer
    >
  > GlobalStoreStreamD;

  /// The stream to store matrix product to shared memory
  typedef cutlass::TileStoreStream<
    ZipTileIterator<
      typename RealEpilogueTraits::SharedStoreStreamD::Iterator,
      typename RealEpilogueTraits::SharedStoreStreamD::Iterator
    >,
    ZipConvert<
      typename RealEpilogueTraits::SharedStoreStreamD::Transformer,
      typename RealEpilogueTraits::SharedStoreStreamD::Transformer
    >
  > SharedStoreStreamD;

  /// The stream to load the matrix product from shared memory
  typedef cutlass::TileLoadStream<
    ZipTileIterator<
      typename RealEpilogueTraits::SharedLoadStreamD::Iterator,
      typename RealEpilogueTraits::SharedLoadStreamD::Iterator
    >,
    ZipConvert<
      typename RealEpilogueTraits::SharedLoadStreamD::Transformer,
      typename RealEpilogueTraits::SharedLoadStreamD::Transformer
    >
  > SharedLoadStreamD;

  /// The scalar type of the source accumulator matrix.
  typedef typename RealEpilogueTraits::ScalarC ScalarC;

  /// The scalar type of the destination accumulator matrix.
  typedef typename RealEpilogueTraits::ScalarD ScalarD;

  //
  // Dependent types
  //

  /// Cover an entire warp-level tile
  typedef typename RealEpilogueTraits::Iterations Iterations;

  /// Parameters structure initialized on the host
  struct Params {
    /// The params for the C iterator.
    typename GlobalLoadStreamC::Params load_stream_c;

    /// The params for the D global iterator.
    typename GlobalStoreStreamD::Params store_stream_d;

    /// Epilogue functor params
    typename Functor::Params functor;

    /// The params for the D shared store iterator.
    typename SharedStoreStreamD::Params shared_store_stream_d;

    /// The params for the D shared load stream.
    typename SharedLoadStreamD::Params shared_load_stream_d;

    /// Stride for C
    platform::Pair<LongIndex, LongIndex> batch_stride_C;

    /// Stride for D
    platform::Pair<LongIndex, LongIndex> batch_stride_D;

    //
    // Methods
    //

    /// Default constructor
    CUTLASS_HOST_DEVICE
    Params() {
      batch_stride_C.first = 0;
      batch_stride_C.second = 0;

      batch_stride_D.first = 0;
      batch_stride_D.second = 0;
    }

    /// Setup the params.
    CUTLASS_HOST_DEVICE int initialize(
        platform::complex<typename Functor::Scalar> alpha,
        platform::complex<typename Functor::Scalar> beta,
        ScalarC const* real_C,
        Index real_ldc,
        ScalarC const* imag_C,
        Index imag_ldc,
        ScalarD* real_D,
        Index real_ldd,
        ScalarD* imag_D,
        Index imag_ldd) {

      int result = functor.initialize(alpha, beta);
      if (result) {
        return result;
      }

      // Setup the params for the global memory iterator for C.
      result = load_stream_c.iterator.first.initialize(
          real_C, real_ldc, real_ldc, 1);

      if (result) {
        return result;
      }

      result = load_stream_c.iterator.second.initialize(
          imag_C, imag_ldc, imag_ldc, 1);

      if (result) {
        return result;
      }

      // Setup the params for the global memory iterator for D.
      result = store_stream_d.iterator.first.initialize(
          real_D, real_ldd, real_ldd, 1);

      if (result) {
        return result;
      }

      result = store_stream_d.iterator.second.initialize(
          imag_D, imag_ldd, imag_ldd, 1);

      if (result) {
        return result;
      }

      return result;
    }

    /// Setup the params.
    CUTLASS_HOST_DEVICE int initialize(
        platform::complex<typename Functor::Scalar> alpha,
        platform::complex<typename Functor::Scalar> beta,
        ScalarC const* real_C,
        Index real_ldc,
        LongIndex stride_C_real,
        ScalarC const* imag_C,
        Index imag_ldc,
        LongIndex stride_C_imag,
        ScalarD* real_D,
        Index real_ldd,
        LongIndex stride_D_real,
        ScalarD* imag_D,
        Index imag_ldd,
        LongIndex stride_D_imag) {

      batch_stride_C.first = stride_C_real;
      batch_stride_C.second = stride_C_imag;

      batch_stride_D.first = stride_D_real;
      batch_stride_D.second = stride_D_imag;

      return initialize(alpha, beta, real_C, real_ldc, imag_C, imag_ldc, real_D, real_ldd, imag_D, imag_ldd);
    }
  };

  /// Shared memory buffer used by epilogue
  typedef ZipTileAllocation<
    typename RealEpilogueTraits::SharedStorage,
    typename RealEpilogueTraits::SharedStorage> SharedStorage;

  /// Functor computing the offset from the threadblock origin per iteration of
  /// the epilogue.
  typedef typename RealEpilogueTraits::GlobalOffset GlobalOffset;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm

namespace platform {

/// Scales a Pair of offsets by an integer factor.
CUTLASS_HOST_DEVICE
Pair<long long, long long> operator*(int s, Pair<long long, long long> _pair) {
  return Pair<long long, long long>(s * _pair.first, s * _pair.second);
}

} // namespace platform

} // namespace cutlass

// clang-format on
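Nearly every type above is built by zipping two copies of the corresponding real-valued component (ZipFragment, ZipTileIterator, ZipConvert), so the epilogue performs each operation once on the real part and once on the imaginary part. A minimal standalone sketch of that zip pattern (illustrative only; the actual CUTLASS types are far richer):

// Standalone sketch of the "zip" pattern used above: one object holds a pair
// of underlying objects and forwards each operation to both halves.
#include <cstdio>

template <typename First, typename Second>
struct ZipApply {
  First first;
  Second second;

  // Apply the first functor to the real part and the second to the imaginary part.
  template <typename In, typename Out>
  void transform(In const &in_real, In const &in_imag, Out &out_real, Out &out_imag) {
    first.transform(in_real, out_real);
    second.transform(in_imag, out_imag);
  }
};

struct Scale2 {
  void transform(int const &in, int &out) { out = 2 * in; }
};

int main() {
  ZipApply<Scale2, Scale2> zip;
  int out_r, out_i;
  zip.transform(3, 4, out_r, out_i);
  std::printf("(%d, %d)\n", out_r, out_i);  // prints (6, 8)
  return 0;
}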
@@ -0,0 +1,558 @@

/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties for complex-valued GEMM targeting Volta's mma.sync
      instruction.

      At present, it expects a split complex representation in global memory in which the real and
      imaginary parts of a complex-valued matrix are disjoint (a structure of arrays). This is in
      contrast with an interleaved complex representation, which is an array of structures.
*/

#pragma once

// clang-format off

#include "cutlass/gemm/clear_accumulators.h"
#include "cutlass/gemm/gemm_config.h"
#include "cutlass/gemm/gemm_stream_pair.h"
#include "cutlass/gemm/threadblock_swizzle.h"
#include "cutlass/gemm/linear_scaling.h"
#include "cutlass/kernel_launch.h"
#include "cutlass/tensor_ref_collection.h"

#include "cutlass/gemm/gemm_desc.h"

#include "cutlass/gemm/volta884_multiplicand.h"
#include "cutlass/gemm/mma_shared_stream.h"
#include "cutlass/gemm/volta884_gemm_traits.h"

#include "cutlass/gemm/volta884_complex_multiply_add.h"
#include "cutlass/gemm/volta884_complex_global_stream.h"
#include "cutlass/gemm/volta884_complex_shared_stream.h"
#include "cutlass/gemm/volta884_complex_gemm_epilogue_traits.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines configuration for Volta884 GEMM
template <
    /// The layout for A.
    MatrixLayout::Kind LayoutA,
    /// Indicates matrix transform on multiplicand A
    MatrixTransform::Kind TransformA,
    /// The layout for B.
    MatrixLayout::Kind LayoutB,
    /// Indicates matrix transform on multiplicand B
    MatrixTransform::Kind TransformB,
    /// The tile size for the GEMM KxNxM.
    typename OutputTile_,
    /// Tile size for warp-level GEMM (K-by-N-by-M)
    typename WarpGemmShape_,
    /// The accumulator type.
    typename Accumulator_,
    /// The source matrix type.
    typename ScalarC_,
    /// The destination matrix type.
    typename ScalarD_,
    /// Number of stages in shared memory
    int StageCount,
    /// Enables or disables launch bounds
    bool LaunchBounds>
struct Volta884ComplexGemmConfig : public GemmConfig<
                                       /// The scalar type for A.
                                       half,
                                       /// The scalar type for B.
                                       half,
                                       /// The scalar type for C.
                                       ScalarC_,
                                       /// The scalar type for D.
                                       ScalarD_,
                                       /// The threadblock tile size
                                       OutputTile_,
                                       /// The functor to do the math in the main loop.
                                       Volta884ComplexMultiplyAdd<WarpGemmShape_,
                                                                  LayoutA,
                                                                  TransformA,
                                                                  half,
                                                                  LayoutB,
                                                                  TransformB,
                                                                  half,
                                                                  Accumulator_>,
                                       /// The number of scalars per LDG for A.
                                       8,
                                       /// The number of scalars per STS for A.
                                       8,
                                       /// The number of scalars per LDS for A.
                                       8,
                                       /// The number of scalars per LDG for B.
                                       8,
                                       /// The number of scalars per STS for B.
                                       8,
                                       /// The number of scalars per LDS for B.
                                       8,
                                       /// The number of scalars per LDG for C and STG for D.
                                       16 / int(sizeof(ScalarD_)),
                                       /// The number of scalars per STS for D.
                                       16 / int(sizeof(ScalarD_)),
                                       /// The number of scalars per LDS for D.
                                       16 / int(sizeof(ScalarD_)),
                                       /// The number of stages in shared memory.
                                       StageCount,
                                       /// If true, separate mainloop is instantiated
                                       true,
                                       /// If true, compute residue in prolog
                                       false,
                                       /// Enables or disables launch bounds
                                       LaunchBounds> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines components of Volta884 GEMM
template <
    /// The layout for A.
    MatrixLayout::Kind LayoutA,
    /// Indicates matrix transform on multiplicand A
    MatrixTransform::Kind TransformA,
    /// The layout for B.
    MatrixLayout::Kind LayoutB,
    /// Indicates matrix transform on multiplicand B
    MatrixTransform::Kind TransformB,
    /// The tile size for the GEMM KxNxM.
    typename OutputTile_,
    /// Tile size for warp-level GEMM (K-by-N-by-M)
    typename WarpGemmShape_,
    /// The accumulator type.
    typename Accumulator_,
    /// The input matrix type.
    typename ScalarC_,
    /// The output matrix type.
    typename ScalarD_,
    /// Number of buffers in shared memory to use
    int StageCount,
    /// The functor to do the math in the epilogue.
    typename EpilogueFunctor_ = SplitComplexLinearScaling<Accumulator_>,
    /// Enables or disables launch bounds
    bool LaunchBounds = false
>
struct Volta884ComplexGemmTraits {

  /// Self-referencing alias for this traits class
  typedef Volta884ComplexGemmTraits<
      LayoutA,
      TransformA,
      LayoutB,
      TransformB,
      OutputTile_,
      WarpGemmShape_,
      Accumulator_,
      ScalarC_,
      ScalarD_,
      StageCount,
      EpilogueFunctor_,
      LaunchBounds> This;

  /// The actual device-side GEMM
  typedef GemmMainloop<This> KernelClass;

  /// Layout of multiplicand A matrix
  static MatrixLayout::Kind const kLayoutA = LayoutA;

  /// Indicates whether the A operand is conjugated
  static MatrixTransform::Kind const kTransformA = TransformA;

  /// Layout of multiplicand B matrix
  static MatrixLayout::Kind const kLayoutB = LayoutB;

  /// Indicates whether the B operand is conjugated
  static MatrixTransform::Kind const kTransformB = TransformB;

  /// Dimensions of threadblock tile (concept Shape)
  typedef OutputTile_ OutputTile;

  /// Shape of warp-level accumulators
  typedef WarpGemmShape_ WarpGemmShape;

  /// Multiplicand A scalar type
  typedef half ScalarA;

  /// Multiplicand B scalar type
  typedef half ScalarB;

  /// Data type of internal accumulator
  typedef Accumulator_ Accumulator;

  /// Data type of input accumulator matrix operand
  typedef ScalarC_ ScalarC;

  /// Data type of output accumulator matrix operand
  typedef ScalarD_ ScalarD;

  /// Shape of individual mma.sync instruction
  typedef Shape<4, 16, 16> InstructionShape;

  /// Tile size for an individual warp-level multiply-add
  typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;

  /// Defines properties about GEMM needed by host code
  typedef Volta884ComplexGemmConfig<
      kLayoutA,
      kTransformA,
      kLayoutB,
      kTransformB,
      OutputTile,
      WarpGemmShape,
      Accumulator,
      ScalarC,
      ScalarD,
      StageCount,
      LaunchBounds>
      GemmConfig;

  //
  // Derived types
  //

  /// Index type
  typedef int Index;

  /// Long index type
  typedef long long LongIndex;

  /// Partitioning of threadblock into warps
  typedef typename ShapeDiv<OutputTile, WarpGemmShape>::Shape WarpDelta;

  /// Number of warps per threadblock
  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;

  /// Defines iterators for A matrix
  typedef Volta884Multiplicand<GemmOperand::kA, kLayoutA, OutputTile, WarpTile, kWarpCount, WarpDelta>
      MultiplicandA;

  /// Defines iterators for B matrix
  typedef Volta884Multiplicand<GemmOperand::kB, kLayoutB, OutputTile, WarpTile, kWarpCount, WarpDelta>
      MultiplicandB;

  //
  // GemmTraits mandatory type definitions
  //

  /// Maps hardware threadblocks to logical partitions of the GEMM
  typedef IdentityBlockSwizzle BlockSwizzle;

  /// Clears accumulators
  typedef ClearAccumulators<ScalarC> ClearAccumulators;

  /// Loads multiplicands from global memory
  typedef GlobalLoadStreamPair<
      Volta884ComplexGlobalLoadStream<GemmOperand::kA,
                                      kLayoutA,
                                      typename MultiplicandA::LoadIterator,
                                      Copy<typename MultiplicandA::LoadIterator::Fragment>,
                                      typename MultiplicandA::StoreIterator,
                                      StageCount>,
      Volta884ComplexGlobalLoadStream<GemmOperand::kB,
                                      kLayoutB,
                                      typename MultiplicandB::LoadIterator,
                                      Copy<typename MultiplicandB::LoadIterator::Fragment>,
                                      typename MultiplicandB::StoreIterator,
                                      StageCount>,
      GemmConfig::kResidueInProlog >
      GlobalLoadStream;

  /// Memory needed to store the threadblock-scoped GEMM tile
  typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage;

  /// Shared memory storage for mainloop phase
  union MainLoopStorage {

    /// Stores the threadblock tile
    ThreadblockTileStorage threadblock_tile;

    /// Storage for GEMM global stream
    typename GlobalLoadStream::SharedStorage global_to_shared_stream;
  };

  /// Loads multiplicands from shared memory
  typedef SharedStreamPair<
      Volta884ComplexSharedLoadStream<typename MultiplicandA::WarpLoadIterator,
                                      Copy<typename MultiplicandA::WarpLoadIterator::Fragment>,
                                      StageCount>,
      Volta884ComplexSharedLoadStream<typename MultiplicandB::WarpLoadIterator,
                                      Copy<typename MultiplicandB::WarpLoadIterator::Fragment>,
                                      StageCount> >
      SharedStream;

  // Multiply-add object specialized for Volta mma.sync
  typedef typename GemmConfig::MultiplyAdd MultiplyAdd;

#if 0
  /// Naive epilogue for updating the output matrix
  typedef Volta884ComplexNaiveEpilogue<ScalarC,
                                       typename MultiplicandA::WarpDelta,
                                       typename MultiplyAdd::Iterations>
      Epilogue;

#else

  /// Efficient epilogue
  typedef MMAEpilogue<
      Volta884ComplexGemmEpilogueTraits<GemmConfig, EpilogueFunctor_>
  > Epilogue;

#endif

  /// Tensor reference to A multiplicand
  typedef ZipTensorRef<
      TensorRef<ScalarA, 2>,
      TensorRef<ScalarA, 2>
  > TensorRefA;

  /// Tensor reference to B multiplicand
  typedef ZipTensorRef<
      TensorRef<ScalarB, 2>,
      TensorRef<ScalarB, 2>
  > TensorRefB;

  /// Tensor reference to C multiplicand
  typedef ZipTensorRef<
      TensorRef<ScalarC, 2>,
      TensorRef<ScalarC, 2>
  > TensorRefC;

  /// Tensor reference to D multiplicand
  typedef ZipTensorRef<
      TensorRef<ScalarD, 2>,
      TensorRef<ScalarD, 2>
  > TensorRefD;

  /// gemm::ProblemDesc<>
  typedef GemmDesc<
      TensorRefA,
      TensorRefB,
      TensorRefC,
      TensorRefD,
      float
  > GemmDesc;

  /// Parameters structure
  struct Params : public KernelLaunchConfiguration {
    /// The dimensions of the GEMM.
    GemmCoord problem_size;

    /// PartitionK_range
    int partitionK_range;

    /// The params for the global load stream
    typename GlobalLoadStream::Params global_to_shared_stream;

    /// The params for the shared load stream
    typename SharedStream::Params shared_stream;

    /// The params for the epilogue.
    typename Epilogue::Params epilogue;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() {}

    /// Initialize the Params struct
    CUTLASS_HOST_DEVICE int initialize(
        Index m,
        Index n,
        Index k,
        platform::complex<typename Epilogue::Scalar> alpha,
        ScalarA const* real_A,
        Index real_lda,
        ScalarA const* imag_A,
        Index imag_lda,
        ScalarB const* real_B,
        Index real_ldb,
        ScalarB const* imag_B,
        Index imag_ldb,
        platform::complex<typename Epilogue::Scalar> beta,
        ScalarC const* real_C,
        Index real_ldc,
        ScalarC const* imag_C,
        Index imag_ldc,
        ScalarD* real_D,
        Index real_ldd,
        ScalarD* imag_D,
        Index imag_ldd) {

      problem_size = make_Coord(k, n, m, 1);

      partitionK_range = problem_size.k();

      // Compute grid dimensions
      BlockSwizzle block_swizzle;
      this->block = dim3(GemmConfig::kThreads);
      this->grid = block_swizzle.get_grid_layout(
          problem_size,
          make_Coord_from_shape<OutputTile>());

      // Initialize global load streams
      global_to_shared_stream.stream_a.initialize(
          make_ZipTensorRef(
              TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_A, real_lda), 0),
              TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_A, imag_lda), 0)
          ),
          0
      );

      global_to_shared_stream.stream_b.initialize(
          make_ZipTensorRef(
              TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_B, real_ldb), 0),
              TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_B, imag_ldb), 0)
          ),
          0
      );

      return epilogue.initialize(
          alpha,
          beta,
          real_C,
          real_ldc,
          imag_C,
          imag_ldc,
          real_D,
          real_ldd,
          imag_D,
          imag_ldd
      );
    }

    /// Initialize the Params struct
    CUTLASS_HOST_DEVICE int initialize(
        Index m,
        Index n,
        Index k,
        platform::complex<typename Epilogue::Scalar> alpha,
        ScalarA const* real_A,
        Index real_lda,
        LongIndex batch_stride_A_real,
        ScalarA const* imag_A,
        Index imag_lda,
        LongIndex batch_stride_A_imag,
        ScalarB const* real_B,
        Index real_ldb,
        LongIndex batch_stride_B_real,
        ScalarB const* imag_B,
        Index imag_ldb,
        LongIndex batch_stride_B_imag,
        platform::complex<typename Epilogue::Scalar> beta,
        ScalarC const* real_C,
        Index real_ldc,
        LongIndex batch_stride_C_real,
        ScalarC const* imag_C,
        Index imag_ldc,
        LongIndex batch_stride_C_imag,
        ScalarD* real_D,
        Index real_ldd,
        LongIndex batch_stride_D_real,
        ScalarD* imag_D,
        Index imag_ldd,
        LongIndex batch_stride_D_imag,
        int batch_count) {

      problem_size = make_Coord(k, n, m, batch_count);
      partitionK_range = problem_size.k();

      // Compute grid dimensions
      BlockSwizzle block_swizzle;
      this->block = dim3(GemmConfig::kThreads);
      this->grid = block_swizzle.get_grid_layout(
          problem_size,
          make_Coord_from_shape<OutputTile>());

      // Initialize global load streams
      global_to_shared_stream.stream_a.initialize(
          make_ZipTensorRef(
              TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_A, real_lda), batch_stride_A_real),
              TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_A, imag_lda), batch_stride_A_imag)
          ),
          0
      );

      global_to_shared_stream.stream_b.initialize(
          make_ZipTensorRef(
              TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_B, real_ldb), batch_stride_B_real),
              TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_B, imag_ldb), batch_stride_B_imag)
          ),
          0
      );

      return epilogue.initialize(
          alpha,
          beta,
          real_C,
          real_ldc,
          batch_stride_C_real,
          imag_C,
          imag_ldc,
          batch_stride_C_imag,
          real_D,
          real_ldd,
          batch_stride_D_real,
          imag_D,
          imag_ldd,
          batch_stride_D_imag
      );
    }
  };

  /// Shared memory storage
  union SharedStorage {
    /// Storage required during mainloop phase
    MainLoopStorage main_loop;

    /// Shared storage needed for epilogue
    typename Epilogue::SharedStorage epilogue;
  };

  /// The memory fence for shared loads.
  static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
    // With a single stage, the shared memory buffer is reused immediately, so a barrier is needed.
    if (StageCount < 2) {
      __syncthreads();
    }
  }

  /// The memory fence for shared stores.
  static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { __syncthreads(); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass

// clang-format on
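The \file comment above pins down the storage convention: split complex (structure of arrays), with real and imaginary parts in disjoint allocations, as reflected by the paired real_*/imag_* pointer arguments of Params::initialize. A standalone sketch contrasting the two representations (illustrative only; not CUTLASS code):

// Standalone sketch: split-complex (structure of arrays) versus interleaved
// (array of structures) storage of a 2x2 column-major complex matrix.
#include <cstdio>

int main() {
  int const m = 2, n = 2;

  // Split representation: two disjoint column-major arrays, as the traits expect.
  float real_A[m * n] = {1.f, 2.f, 3.f, 4.f};
  float imag_A[m * n] = {5.f, 6.f, 7.f, 8.f};

  // Interleaved representation, shown for contrast (not supported above).
  float interleaved_A[2 * m * n] = {1.f, 5.f, 2.f, 6.f, 3.f, 7.f, 4.f, 8.f};

  int lda = m;           // leading dimension of each split array
  int row = 1, col = 1;

  std::printf("split:       %g + %gi\n", real_A[row + col * lda], imag_A[row + col * lda]);
  std::printf("interleaved: %g + %gi\n",
              interleaved_A[2 * (row + col * lda)], interleaved_A[2 * (row + col * lda) + 1]);
  return 0;
}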
|
|
@ -0,0 +1,315 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements efficient loading of the thread block-level tile from global memory and
|
||||
storing
|
||||
to shared memory.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/zip_tile_iterator.h"
|
||||
#include "cutlass/zip_tensor_ref.h"
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/util/pair.h"
|
||||
|
||||
#include "cutlass/gemm/mma_global_stream.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
///! Stream adapter for loading threadblock-scoped GEMM tiles and storing to shared memory
|
||||
template <
|
||||
/// Identifies multiplicand
|
||||
GemmOperand::Kind Operand,
|
||||
/// Layout of source matrix in global memory
|
||||
MatrixLayout::Kind Layout,
|
||||
/// Iterator for loading threadblock-scoped tiles
|
||||
typename LoadIterator_,
|
||||
/// Transformation functor for transforming fragments
|
||||
typename Transformer_,
|
||||
/// Iterator for storing threadblock-scoped tiles to shared memory
|
||||
typename StoreIterator_,
|
||||
/// Number of stores before iterator wraps - zero indicates no wrapping
|
||||
int StageCount>
|
||||
struct Volta884ComplexGlobalLoadStream {
|
||||
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Identifies the operand
|
||||
static GemmOperand::Kind const kOperand = Operand;
|
||||
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = Layout;
|
||||
|
||||
/// Load-store stream for real-valued matrices
|
||||
typedef MMAGlobalLoadStream<Operand, Layout, LoadIterator_, Transformer_, StoreIterator_, StageCount> RealLoadStoreStream;
|
||||
|
||||
/// Loads a pair of real-valued fragments
|
||||
typedef ZipTileIterator<LoadIterator_, LoadIterator_> LoadIterator;
|
||||
|
||||
/// Zips a pair of transformers
|
||||
typedef ZipConvert<Transformer_, Transformer_> Transformer;
|
||||
|
||||
/// Stores a pair of real-valued ragments
|
||||
typedef ZipTileIterator<StoreIterator_, StoreIterator_> StoreIterator;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStageCount = StageCount;
|
||||
|
||||
/// Predicate vector
|
||||
typedef typename RealLoadStoreStream::PredicateVector PredicateVector;
|
||||
|
||||
/// The fragment that is copied from shared memory.
|
||||
typedef typename LoadIterator::Fragment FetchedFragment;
|
||||
/// The fragment that is obtained after the transformation by the transformer.
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
/// Make sure the fragments match.
|
||||
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
|
||||
"");
|
||||
/// The output fragment.
|
||||
typedef TransformedFragment Fragment;
|
||||
  /// Make sure the transformed fragment is the same as the store fragment.
  static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
                "");

  /// Index type
  typedef typename RealLoadStoreStream::Index Index;

  /// Long index type
  typedef typename RealLoadStoreStream::LongIndex LongIndex;

  /// The params.
  struct Params {

    //
    // Type definitions
    //

    /// Matrix reference
    typedef ZipTensorRef<
      TensorRefBatchStrided<half const, 2>,
      TensorRefBatchStrided<half const, 2> > SourceTensorRef;

    /// Helper
    static int const kElementsPerLdg = LoadIterator::First::Tile::kC;

    //
    // Data members
    //

    /// Source tensor reference
    platform::Pair<LongIndex, LongIndex> batch_stride;

    // The load iterator.
    typename LoadIterator::Params load_iterator;

    // Offset to residue.
    Index offset_to_residue;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() {}

    ///
    CUTLASS_HOST_DEVICE
    Params(SourceTensorRef const &ref, Index _offset_to_residue) {
      initialize(ref, _offset_to_residue);
    }

    CUTLASS_HOST_DEVICE
    int initialize(SourceTensorRef const &ref, Index _offset_to_residue) {

      batch_stride.first = ref.first.tensor_stride;
      batch_stride.second = ref.second.tensor_stride;

      offset_to_residue = _offset_to_residue;
      load_iterator.first.initialize(
        TensorRef<half const, 4>(
          ref.first.at().data(),
          make_Coord(ref.first.at().stride(0) * kElementsPerLdg, ref.first.at().stride(0), kElementsPerLdg)
        )
      );
      load_iterator.second.initialize(
        TensorRef<half const, 4>(
          ref.second.at().data(),
          make_Coord(ref.second.at().stride(0) * kElementsPerLdg, ref.second.at().stride(0), kElementsPerLdg)
        )
      );
      return 0;
    }
  };

  /// Empty shared storage
  struct SharedStorage {};

  /// Shared memory allocation for the tile
  typedef TileAllocation<
    typename RealLoadStoreStream::StoreIterator::Scalar,
    typename ShapeMul<
      typename RealLoadStoreStream::StoreIterator::OperandShape,
      Shape<kStageCount, 1, 1, 1>
    >::Shape
  > RealThreadblockTileStorage;

  /// Threadblock tile allocation
  typedef ZipTileAllocation<
    RealThreadblockTileStorage,
    RealThreadblockTileStorage
  > ThreadblockTileStorage;

  /// Reference to ThreadblockTileStorage
  typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;

  //
  // Data members
  //

  ///! The parameters
  Params params;

  ///! Dimensions of global memory tile
  Coord<3> threadblock_offset;

  ///! Multiplicand bounds
  Coord<3> multiplicand_bounds;

  ///! Iterator to load threadblock tiles from global memory
  LoadIterator load_iterator;

  ///! Predicate vector
  PredicateVector predicates;

  ///! The fragment to fetch from shared memory.
  FetchedFragment fetched_fragment;

  ///! Functor to transform fragments after they have been loaded
  Transformer transformer;

  ///! The fragment to convert the data after it has been fetched from shared memory.
  TransformedFragment transformed_fragment;

  ///! Iterator to store threadblock tiles to shared memory
  StoreIterator store_iterator;

  ///! Counter
  int stage_index;

  //
  // Methods
  //

  /// Constructor
  CUTLASS_DEVICE Volta884ComplexGlobalLoadStream(Params const &_params,
                                                 SharedStorage &shared_storage,
                                                 ThreadblockTileRef const &threadblock_tile_ref,
                                                 Coord<3> const bounds,
                                                 Coord<3> const &block)
      : params(_params),
        threadblock_offset(RealLoadStoreStream::project_coordinate(block)),
        multiplicand_bounds(RealLoadStoreStream::project_coordinate(bounds, 1)),
        load_iterator(params.load_iterator, threadblock_offset),
        transformer(),
        store_iterator(threadblock_tile_ref),
        stage_index(0) {

    // initialize predicates used to guard loads
    load_iterator.initialize_predicates(
        predicates.begin(), multiplicand_bounds, threadblock_offset);
  }

  /// Loads the data from global memory
  CUTLASS_DEVICE void copy() {
    load_iterator.load_post_increment(fetched_fragment, predicates.begin());
  }

  /// Transform and commit the data to shared memory
  CUTLASS_DEVICE void commit() {

    transformer.transform(fetched_fragment, transformed_fragment);
    store_iterator.store_post_increment(transformed_fragment);

    ++stage_index;
    if (kStageCount && stage_index == kStageCount) {
      store_iterator -= kStageCount;
      stage_index = 0;
    }
  }

  /// Computes a predicate mask for loads during final threadblock tile load iteration
  CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
    // That's the residue!
    Coord<3> _block_offset = threadblock_offset;
    if (kOperand == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor) {
      // K-strided
      _block_offset =
          make_Coord(threadblock_offset[0], multiplicand_bounds[1] - k, threadblock_offset[2]);
    } else {
      // K-contiguous
      _block_offset = make_Coord(threadblock_offset[0],
                                 threadblock_offset[1],
                                 multiplicand_bounds[2] - k / LoadIterator::First::Tile::kC);
    }

    load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, _block_offset);
    fetched_fragment.clear();
  }

  CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {}

  CUTLASS_DEVICE void rollback() {}

  /// Adds a Coord<3> to the underlying global load iterator
  CUTLASS_DEVICE Volta884ComplexGlobalLoadStream &operator+=(Coord<3> const &offset) {
    load_iterator += offset;
    return *this;
  }

  /// Adds an offset based on batch stride
  CUTLASS_DEVICE Volta884ComplexGlobalLoadStream &add_batch_offset(int batch_id) {
    load_iterator.first.add_pointer_offset(params.batch_stride.first * batch_id);
    load_iterator.second.add_pointer_offset(params.batch_stride.second * batch_id);
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass

// clang-format on
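
// Illustrative sketch (not part of this commit): how a mainloop might drive the
// Volta884ComplexGlobalLoadStream defined above. Names such as `source_ref`,
// `tile_ref`, `bounds`, and `block` are hypothetical placeholders; the surrounding
// kernel scaffolding is omitted.
//
//   Volta884ComplexGlobalLoadStream::Params params(source_ref, offset_to_residue);
//   Volta884ComplexGlobalLoadStream stream(params, shared_storage, tile_ref, bounds, block);
//
//   for (Index k = K; k > 0; k -= Tile::kD) {
//     stream.copy();    // issue global loads for the next threadblock tile
//     // ... warp-level math on the previously committed stage ...
//     stream.commit();  // transform the fetched fragment and store it to shared memory
//   }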

@ -0,0 +1,319 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements warp-level multiply-accumulate operations using Volta's mma.sync instruction
      for complex-valued data types.
*/

#pragma once

#include "cutlass/util/complex.h"
#include "cutlass/zip_fragment.h"
#include "cutlass/gemm/volta884_multiply_add.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
  /// Shape of a warp-level GEMM (K-by-N-by-M)
  typename WarpGemmShape_,
  /// Layout of multiplicand A
  MatrixLayout::Kind LayoutA,
  /// Indicates matrix transform on multiplicand A
  MatrixTransform::Kind TransformA,
  /// Data type of multiplicand A
  typename ScalarA_,
  /// Layout of multiplicand B
  MatrixLayout::Kind LayoutB,
  /// Indicates matrix transform on multiplicand B
  MatrixTransform::Kind TransformB,
  /// Data type of multiplicand B
  typename ScalarB_,
  /// Data type of accumulators
  typename ScalarC_,
  /// If true, A operand is conjugated
  bool ConjugateA = false,
  /// If true, B operand is conjugated
  bool ConjugateB = false,
  /// If true, infinite results are saturated to +-MAX_FLOAT
  bool SatFinite = false>
struct Volta884ComplexMultiplyAdd {
  //
  // Constant and type definitions
  //

  /// Shape of a warp-level GEMM (K-by-N-by-M)
  typedef WarpGemmShape_ WarpGemmShape;

  /// Alias for WarpGemmShape (K-by-N-by-M)
  typedef WarpGemmShape_ AccumulatorsPerWarp;

  /// Most of the Volta884 code assumes interleaved 32x32 tiles
  typedef Shape<4, 32, 32> InterleavedTileShape;

  /// Shape of an individual warp-wide mma.sync instruction
  typedef Shape<4, 16, 16> InstructionShape;

  /// Shape of a warp-level matrix multiply operation
  typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;

  /// Verify WarpTile is a multiple of fundamental 32x32 interleaved tile
  static_assert(!(WarpTile::kH % InterleavedTileShape::kH) &&
                !(WarpTile::kW % InterleavedTileShape::kW) && WarpTile::kD == 4,
                "WarpTile must be a multiple of InterleavedTileShape.");

  /// Layout of A multiplicand
  static MatrixLayout::Kind const kLayoutA = LayoutA;

  /// Indicates matrix transform on multiplicand A
  static MatrixTransform::Kind const kTransformA = TransformA;

  /// Layout of B multiplicand
  static MatrixLayout::Kind const kLayoutB = LayoutB;

  /// Indicates matrix transform on multiplicand B
  static MatrixTransform::Kind const kTransformB = TransformB;

  /// The type for A.
  typedef ScalarA_ ScalarA;
  /// The type for B.
  typedef ScalarB_ ScalarB;
  /// The type for C and D.
  typedef ScalarC_ ScalarC;

  /// If true, infinite results are saturated to +-MAX_FLOAT
  static bool const kSatFinite = SatFinite;

  /// Hard-coded compute type supported on Volta
  static arch::ComputeType::Kind const kComputeType = arch::ComputeType::kDefault;

  /// Underlying matrix multiply-add operator
  typedef Volta884MultiplyAdd<WarpGemmShape,
                              kLayoutA,
                              ScalarA,
                              kLayoutB,
                              ScalarB,
                              ScalarC>
      RealMultiplyAdd;

  /// Fragment definition for A multiplicand
  typedef ZipFragment<typename RealMultiplyAdd::FragmentA, typename RealMultiplyAdd::FragmentA>
      FragmentA;

  /// Fragment definition for B multiplicand
  typedef ZipFragment<typename RealMultiplyAdd::FragmentB, typename RealMultiplyAdd::FragmentB>
      FragmentB;

  /// Fragment definition for accumulators
  typedef ZipFragment<typename RealMultiplyAdd::Accumulators,
                      typename RealMultiplyAdd::Accumulators>
      Accumulators;

  /// Number of mma.sync operations performed. See Volta884MultiplyAdd::Iterations for details.
  typedef typename RealMultiplyAdd::Iterations Iterations;

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE Volta884ComplexMultiplyAdd() {}

  /// Complex multiply-accumulate: D = A * B + C
  CUTLASS_DEVICE void multiply_add(FragmentA const& A,
                                   FragmentB const& B,
                                   Accumulators const& C,
                                   Accumulators& D) {
    RealMultiplyAdd op;

    // complex-valued multiply-add
    op.multiply_add(A.first, B.first, C.first, D.first);
    op.multiply_add(A.first, B.second, C.second, D.second, kTransformB == MatrixTransform::kConjugate);
    op.multiply_add(A.second, B.first, C.second, D.second, kTransformA == MatrixTransform::kConjugate);
    op.multiply_add(A.second, B.second, C.first, D.first,
      !((kTransformA == MatrixTransform::kConjugate) ^ (kTransformB == MatrixTransform::kConjugate)));
  }
};
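
// Illustrative sketch (not part of this file): a scalar reference for the four
// real-valued multiply-adds issued by multiply_add() above. The identity is
// (a_r + i*a_i)(b_r + i*b_i) = (a_r*b_r - a_i*b_i) + i*(a_r*b_i + a_i*b_r);
// the negate flags passed to the real-valued operator supply the -a_i*b_i term
// and flip signs when an operand is conjugated, matching kTransformA/kTransformB.
inline void complex_mad_reference(float a_r, float a_i, float b_r, float b_i,
                                  bool conj_a, bool conj_b, float &d_r, float &d_i) {
  if (conj_a) a_i = -a_i;  // conjugation negates the imaginary component
  if (conj_b) b_i = -b_i;
  d_r += a_r * b_r - a_i * b_i;
  d_i += a_r * b_i + a_i * b_r;
}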

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Complex-valued epilogue
template <typename Accumulator, typename WarpDelta, typename Iterations>
struct Volta884ComplexNaiveEpilogue {
  /// Accumulator data type
  typedef Accumulator ScalarC;

  /// Output accumulator type
  typedef Accumulator ScalarD;

  /// BLAS Scalar type
  typedef Accumulator Scalar;

  /// Real-valued epilogue
  typedef Volta884NaiveEpilogue<Accumulator, WarpDelta, Iterations> RealEpilogue;

  /// Params object
  struct Params {
    /// Parameters for the real-valued part
    typename RealEpilogue::Params real;

    /// Parameters for the imaginary-valued part
    typename RealEpilogue::Params imag;

    //
    // Methods
    //

    /// Default constructor
    CUTLASS_HOST_DEVICE Params() {}

    /// Constructs from params object
    CUTLASS_HOST_DEVICE Params(typename RealEpilogue::Params const& _real,
                               typename RealEpilogue::Params const& _imag)
        : real(_real), imag(_imag) {}

    /// Construct from pointers
    CUTLASS_HOST_DEVICE Params(ScalarC* _real, int _ldr, ScalarC* _imag, int _ldi)
        : real(_real, _ldr), imag(_imag, _ldi) {}

    /// Construct from scalars and pointers
    CUTLASS_HOST_DEVICE Params(
      platform::complex<Scalar> const &alpha,
      platform::complex<Scalar> const &beta,
      ScalarC const *real_C,
      int real_ldc,
      ScalarC const *imag_C,
      int imag_ldc,
      ScalarD *real_D,
      int real_ldd,
      ScalarD *imag_D,
      int imag_ldd
    ):
      real(real_D, real_ldd, alpha.real(), beta.real()),
      imag(imag_D, imag_ldd, alpha.real(), beta.real()) { }

    /// Initializer method
    CUTLASS_HOST_DEVICE
    int initialize(
      platform::complex<Scalar> const &alpha,
      platform::complex<Scalar> const &beta,
      ScalarC const *real_C,
      int real_ldc,
      ScalarC const *imag_C,
      int imag_ldc,
      ScalarD *real_D,
      int real_ldd,
      ScalarD *imag_D,
      int imag_ldd
    ) {

      real = typename RealEpilogue::Params(real_D, real_ldd, alpha.real(), beta.real());
      imag = typename RealEpilogue::Params(imag_D, imag_ldd, alpha.real(), beta.real());

      return 0;
    }
  };

  /// Shared storage
  struct SharedStorage {};

  /// Accumulator fragment definition
  typedef ZipFragment<
    typename RealEpilogue::Accumulators,
    typename RealEpilogue::Accumulators> Accumulators;

  //
  // Data members
  //

  /// Epilogue for real part
  RealEpilogue real;

  /// Epilogue for imaginary part
  RealEpilogue imag;

  //
  // Methods
  //

  /// Constructs a complex-valued epilogue
  CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(
      Params const& _params, Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
      : real(_params.real, _problem_size), imag(_params.imag, _problem_size) {}

  /// Constructs a complex-valued epilogue
  CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(ScalarC* _real,
                                              int _ldr,
                                              ScalarC* _imag,
                                              int _ldi,
                                              Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
      : real(_real, _ldr, _problem_size), imag(_imag, _ldi, _problem_size) {}

  /// Constructs a complex-valued epilogue
  CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(Params const& _params,
                                              SharedStorage& shared_storage,
                                              Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
      : real(_params.real, _problem_size), imag(_params.imag, _problem_size) {}

  /// Sets accumulators to zero
  CUTLASS_DEVICE void clear(Accumulators& C) {
    C.first.clear();
    C.second.clear();
  }

  /// Naive load operation for debugging
  CUTLASS_DEVICE void load(Accumulators& C,
                           Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
    real.load(C.first, threadblock_offset);
    imag.load(C.second, threadblock_offset);
  }

  /// Naive store operation for debugging
  CUTLASS_DEVICE void store(Accumulators const& C,
                            Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
    real.store(C.first, threadblock_offset);
    imag.store(C.second, threadblock_offset);
  }

  /// CUTLASS Epilogue interface
  CUTLASS_DEVICE void epilogue(Accumulators const& C,
                               Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
                               int batch_id = 0) {
    real.store(C.first, threadblock_offset);
    imag.store(C.second, threadblock_offset);
  }
};
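
// Illustrative sketch (not part of this file): initializing the naive epilogue's
// Params from host code. `Epilogue` stands for a concrete instantiation of
// Volta884ComplexNaiveEpilogue, and the pointer/leading-dimension names are
// hypothetical placeholders for separately allocated real and imaginary planes.
//
//   platform::complex<float> alpha(1.0f, 0.0f), beta(0.0f, 0.0f);
//   typename Epilogue::Params params;
//   params.initialize(alpha, beta,
//                     real_C, ldc, imag_C, ldc,    // source planes
//                     real_D, ldd, imag_D, ldd);   // destination planes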

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass

@ -0,0 +1,152 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements loading of tiles from shared memory into fragments for warp-level
      matrix multiply-accumulate operations on complex-valued data.
*/

#pragma once

#include "cutlass/convert.h"
#include "cutlass/zip_fragment.h"
#include "cutlass/zip_tensor_ref.h"
#include "cutlass/zip_tile_iterator.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Stream from shared memory to fragments for warp-level matrix multiply-accumulate
template <
  /// The load iterator.
  typename Iterator_,
  /// The transformer to be applied after the data has been copied from shared memory.
  typename Transformer_ = Copy<typename Iterator_::Fragment>,
  /// Number of increments before iterator wraps - zero indicates no wrapping
  int StageCount = 1>
struct Volta884ComplexSharedLoadStream {
  /// The load iterator.
  typedef Iterator_ RealIterator;

  /// Zips two real-valued iterators together
  typedef ZipTileIterator<RealIterator, RealIterator> Iterator;

  /// The transformer.
  typedef Transformer_ RealTransformer;

  /// Zips two transformers
  typedef ZipConvert<RealTransformer, RealTransformer> Transformer;

  /// Number of increments before iterator wraps - zero indicates no wrapping
  static int const kStageCount = StageCount;

  /// The fragment that is copied from shared memory.
  typedef typename Iterator::Fragment FetchedFragment;

  /// The fragment that is obtained after the transformation by the transformer.
  typedef typename Transformer::OutputFragment TransformedFragment;

  /// Make sure the fragments match.
  static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
                "");

  /// The output fragment.
  typedef TransformedFragment Fragment;

  /// Reference type
  typedef ZipTensorRef<
    TensorRef<half, 4>,
    TensorRef<half, 4>
  > TensorRef;

  /// Parameters passed from host
  struct Params { };

  //
  // Data members
  //

  /// Iterator for loading fragments for warp-level matrix multiply-accumulate
  Iterator iterator;

  /// Fetched fragment
  FetchedFragment fetched[2];

  /// The transformer.
  Transformer transformer;

  /// Transformed fragment
  TransformedFragment transformed[2];

  /// Counts the number of stages
  int stage_index;

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE Volta884ComplexSharedLoadStream() : stage_index(0) {}

  /// Ctor.
  CUTLASS_DEVICE Volta884ComplexSharedLoadStream(Params const &_params,
                                                 TensorRef const &ref)
      : iterator(ref), stage_index(0) {}

  /// Load the data from shared memory to the fetch fragment.
  CUTLASS_DEVICE void copy(int step) {
    iterator.load(fetched[step % 2],
                  make_Coord(step + stage_index * Iterator::First::VectorizedShape::kD, 0, 0, 0));
  }

  /// Commit the data.
  CUTLASS_DEVICE void commit(int step) {
    transformer.transform(fetched[step % 2], transformed[step % 2]);
  }

  /// Gets the transformed fragment
  CUTLASS_DEVICE
  TransformedFragment &fragment(int step) { return transformed[step % 2]; }

  /// Gets the transformed fragment
  CUTLASS_DEVICE
  TransformedFragment const &fragment(int step) const { return transformed[step % 2]; }

  /// Increment the stage.
  CUTLASS_DEVICE void inc_stage() {
    ++stage_index;
    if (kStageCount && stage_index == kStageCount) {
      stage_index = 0;
    }
  }
};
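
// Illustrative sketch (not part of this file): the intended double-buffered use of
// copy()/commit()/fragment() above. The fetch for step k+1 is issued before the
// transform for step k is consumed, overlapping shared memory loads with math.
// `shared_stream`, `warp_mma`, and `kSteps` are hypothetical placeholders.
//
//   shared_stream.copy(0);                       // prime the pipeline
//   for (int step = 0; step < kSteps; ++step) {
//     shared_stream.copy(step + 1);              // fetch the fragment for the next step
//     shared_stream.commit(step);                // transform the fragment fetched earlier
//     warp_mma(shared_stream.fragment(step));    // consume the transformed fragment
//   }
//   shared_stream.inc_stage();                   // advance to the next shared memory stage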

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass

@ -0,0 +1,771 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
      with the computed matrix product.
*/

#pragma once

// clang-format off

#include "cutlass/tile_stream.h"
#include "cutlass/tile_allocation.h"

#include "cutlass/gemm/mma_shared_stream.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Abstraction to select accumulators from an accumulator tile for each iteration of the epilogue
template <typename WarpGemmShape, typename WarpDelta, typename Scalar>
struct Volta884SelectAccumulators;

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Selects accumulators from Volta mma.sync.F32 layout
template <typename WarpGemmShape_, typename WarpDelta_>
struct Volta884SelectAccumulators<WarpGemmShape_, WarpDelta_, float> {
  /// Shape of the warp-level matrix multiply operation
  typedef WarpGemmShape_ WarpGemmShape;

  /// Describes tiling of warp elements
  typedef WarpDelta_ WarpDelta;

  /// Data type of scalar
  typedef float Scalar;

  //
  // Derived types and constants
  //

  /// (Actual) number of accumulators held by each individual thread
  static int const kAccumulatorsPerThread = (WarpGemmShape::kH * WarpGemmShape::kW) / kWarpSize;

  /// Accumulators fragment
  typedef Fragment<Scalar, kAccumulatorsPerThread> Accumulators;

  /// Number of warps
  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;

  /// Interleaved mma.sync shape
  typedef Shape<4, 32, 32> MmaTileShape;

  /// Hard-coded for FP32 layouts
  typedef Shape<1, WarpGemmShape::kW / MmaTileShape::kW, 4> Elements;

  /// Number of elements
  static int const kElements = ShapeCount<Elements>::kCount;

  /// Slice of accumulators
  typedef Fragment<Scalar, kElements> Fragment;

  //
  // Methods
  //

  /// Selects accumulators for a given iteration of the epilogue
  CUTLASS_DEVICE
  Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
    Fragment frag;

    static int const kAccumPerOp = 8;

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Elements::kH; ++j) {

      // selects the 32x32 tile
      Coord<2> tile_32x32 = make_Coord(idx[0] / 8, j);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Elements::kW; ++i) {
        Coord<2> mma_op = make_Coord(((idx[0] >> 1) & 1), i / 2);

        int element = ((i & 1) << 1) | (idx[0] & 1) | (idx[0] & 4);

        int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]);

        frag[i + j * Elements::kW] = accum[element + kAccumPerOp * mma_op_idx];
      }
    }

    return frag;
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Selects accumulators from Volta mma.sync.F16 layout
template <typename WarpGemmShape_, typename WarpDelta_>
struct Volta884SelectAccumulators<WarpGemmShape_, WarpDelta_, half> {
  /// Shape of the warp-level matrix multiply operation
  typedef WarpGemmShape_ WarpGemmShape;

  /// Describes tiling of warp elements
  typedef WarpDelta_ WarpDelta;

  /// Data type of accumulator elements
  typedef half Scalar;

  //
  // Derived types and constants
  //

  /// (Actual) number of accumulators held by each individual thread
  static int const kAccumulatorsPerThread = (WarpGemmShape::kH * WarpGemmShape::kW) / kWarpSize;

  /// Accumulators fragment
  typedef Fragment<Scalar, kAccumulatorsPerThread> Accumulators;

  /// Number of warps
  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;

  /// Interleaved mma.sync shape
  typedef Shape<4, 32, 32> MmaTileShape;

  /// Hard-coded for FP16 layouts
  typedef Shape<1, WarpGemmShape::kW / MmaTileShape::kW, 2> Elements;

  /// Number of elements
  static int const kElements = ShapeCount<Elements>::kCount;

  /// Slice of accumulators
  typedef Fragment<Scalar, kElements> Fragment;

  //
  // Methods
  //

  /// Selects accumulators for a given iteration of the epilogue
  CUTLASS_DEVICE
  Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
    Fragment frag;

    static int const kAccumPerOp = 8;

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Elements::kH; ++j) {

      // selects the 32x32 tile
      Coord<2> tile_32x32 = make_Coord(idx[0] / 16, j);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Elements::kW; ++i) {

        Coord<2> mma_op = make_Coord(((idx[0] >> 2) & 1), i & 1);

        int element = (idx[0] & 3) | ((idx[0] >> 1) & 4);

        int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]);

        frag[i + j * Elements::kW] = accum[element + kAccumPerOp * mma_op_idx];
      }
    }

    return frag;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
//
//
//
////////////////////////////////////////////////////////////////////////////////////////////////////

template <
  /// The warp-level GEMM tile
  typename WarpGemmTile_,
  /// Tiling of warp accumulator elements
  typename WarpDelta_,
  /// Size of vector to load or store
  int AccessSize,
  /// The accumulators fragment type - implies accumulator layout
  typename Accumulators_>
struct Volta884EpilogueGlobalTileTraits;

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Global tile traits specialized for Volta mma.sync.F32 layout
template <
  /// The warp-level GEMM tile
  typename WarpGemmTile_,
  /// Tiling of warp accumulator elements
  typename WarpDelta_,
  /// Size of vector to load or store
  int AccessSize>
struct Volta884EpilogueGlobalTileTraits<WarpGemmTile_, WarpDelta_, AccessSize, float> {
  /// Shape of warp-scoped GEMM tile
  typedef WarpGemmTile_ WarpGemmTile;

  /// Structure of MMA
  typedef WarpDelta_ WarpDelta;

  /// Access size of input/output elements
  static int const kAccessSize = AccessSize;

  /// Scalar type of accumulators - used to imply accumulator layout, not the data
  typedef float Accumulators;

  /// Strides for immediate offset computation
  typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;

  //typedef Shape<2, 2, 1, 1> Iterations;

  /// Hard-coded pitch between Volta mma.sync Quad Pair tiles
  static int const kMmaQuadPairWidth = 16;

  /// Hard-coded pitch between warp tiles
  static int const kInterleavedTileWidth = 32;

  /// Number of actual threads
  static int const kThreadCount = (WarpDelta::kH * WarpDelta::kW) * kWarpSize;

  /// Shape of the tile
  typedef Shape<2 * WarpDelta::kH, 2, WarpGemmTile::kW * WarpDelta::kW, 1> Tile;

  /// Number of iterations
  typedef Shape<2 * WarpDelta::kH,
                (kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
                (kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
                1> Iterations;

  /// Delta between accesses
  typedef Shape<kMmaQuadPairWidth, 2, WarpDelta::kW * kWarpSize, 1> Delta;

  /// Number of warps in threadblock
  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;

  /// Custom thread-offset function
  struct ThreadOffset {
    CUTLASS_DEVICE
    Coord<4> operator()() {

      int tid = threadIdx.x;

      int residual_w = (tid / (Tile::kW));
      int offset_w = (tid % (Tile::kW));

      int offset_h = (residual_w % Tile::kH);
      int offset_d = (residual_w / Tile::kH);

      Coord<4> offset = make_Coord(offset_d * Delta::kD, offset_h * Delta::kH, offset_w, 0);

      return offset;
    }
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Global tile traits specialized for Volta mma.sync.F16 layout
template <
  /// The warp-level GEMM tile
  typename WarpGemmTile_,
  /// Tiling of warp accumulator elements
  typename WarpDelta_,
  /// Size of vector to load or store
  int AccessSize>
struct Volta884EpilogueGlobalTileTraits<WarpGemmTile_, WarpDelta_, AccessSize, half> {
  /// Shape of warp-scoped GEMM tile
  typedef WarpGemmTile_ WarpGemmTile;

  /// Structure of MMA tiles
  typedef WarpDelta_ WarpDelta;

  /// Access size of input/output elements
  static int const kAccessSize = AccessSize;

  /// Scalar type of accumulators - used to imply accumulator layout, not the data
  typedef half Accumulators;

  /// Hard-coded pitch between Volta mma.sync Quad Pair tiles
  static int const kMmaQuadPairWidth = 16;

  /// Hard-coded pitch between warp tiles
  static int const kInterleavedTileWidth = 32;

  /// Number of participating threads
  static int const kThreadCount = kWarpSize * WarpDelta::kH * WarpDelta::kW;

  /// Shape of the tile
  typedef Shape<1, 2 * WarpDelta::kH, WarpGemmTile::kW * WarpDelta::kW, 1> Tile;

  /// Strides for immediate offset computation
  typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;

  /// Number of iterations
  typedef Shape<
    1,
    (kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
    (kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
    1> Iterations;

  /// Delta between thread-level accesses
  typedef typename platform::conditional<
    kThreadCount >= Tile::kW,
    Shape<1, kMmaQuadPairWidth * (kThreadCount / Tile::kW), 1, 1>,
    Shape<1, kMmaQuadPairWidth, kThreadCount, 1>
  >::type Delta;

  /// Number of warps in threadblock
  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;

  /// Custom thread-offset function
  struct ThreadOffset {
    CUTLASS_DEVICE
    Coord<4> operator()() {

      int tid = threadIdx.x;

      int residual_w = (tid / (Tile::kW));
      int offset_w = (tid % (Tile::kW));

      int offset_h = (residual_w % Tile::kH);
      int offset_d = (residual_w / Tile::kH);

      Coord<4> offset = make_Coord(offset_d * Delta::kD, offset_h * kMmaQuadPairWidth, offset_w, 0);

      return offset;
    }
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Global offset functor for Volta884 epilogues
//
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename WarpDelta, typename AccumulatorType>
struct Volta884EpilogueGlobalOffset;

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Functor computing the offset from the threadblock origin per iteration of
/// the epilogue. Specialized for Volta mma.sync.F32
template <typename WarpDelta>
struct Volta884EpilogueGlobalOffset<WarpDelta, float> {

  /// mma.sync instructions are arranged as spatially overlapping 32x32 tiles
  typedef Shape<4, 32, 32> MmaTileShape;

  CUTLASS_DEVICE
  Coord<3> operator()(Coord<2> const &iteration) const {

    int h = iteration[0];

    // C++ needs a better way to express bit swizzling
    int h_offset = ((h & 1) | ((h & 2) << 1) | (((h & 4) >> 2) * 8) |
                    (((h & 8) >> 3) * WarpDelta::kH * MmaTileShape::kH));

    return make_Coord(0, h_offset, iteration[1] * MmaTileShape::kW * WarpDelta::kW);
  }
};
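
// Annotation (not part of this file): the swizzle above expands to a per-bit row
// stride. Reading h bit by bit: bit 0 contributes +1 to the row offset, bit 1
// contributes +4, bit 2 contributes +8, and bit 3 contributes
// +(MmaTileShape::kH * WarpDelta::kH). The column offset advances by one
// MmaTileShape::kW * WarpDelta::kW span per unit of iteration[1].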

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Functor computing the offset from the threadblock origin per iteration of
/// the epilogue. Specialized for Volta mma.sync.F16
template <typename WarpDelta>
struct Volta884EpilogueGlobalOffset<WarpDelta, half> {

  /// mma.sync instructions are arranged as spatially overlapping 32x32 tiles
  typedef Shape<4, 32, 32> MmaTileShape;

  CUTLASS_DEVICE
  Coord<3> operator()(Coord<2> const &iteration) const {

    int h = iteration[0];

    // C++ needs a better way to express bit swizzling
    int h_offset = (h & 15) | (h & 16) * 2 * WarpDelta::kH;

    Coord<3> offset = make_Coord(0, h_offset, iteration[1] * MmaTileShape::kW * WarpDelta::kW);
    return offset;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Epilogue traits for Volta884 epilogue
//
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue traits for Volta884 GEMMs
template <
  /// The threadblock GEMM tile
  typename OutputTile_,
  /// The warp-level GEMM tile
  typename WarpGemmTile_,
  /// Tiling of warp accumulator elements
  typename WarpDelta_,
  /// The accumulators fragment type.
  typename Accumulators_,
  /// Selects a slice of accumulators
  typename SelectAccumulators_,
  /// The iterator to load source matrix from global memory.
  typename GlobalLoadStreamC_,
  /// The iterator to store the final GEMM computation to global memory.
  typename GlobalStoreStreamD_,
  /// The stream to store matrix product to shared memory
  typename SharedStoreStreamD_,
  /// The stream to load the matrix product from shared memory
  typename SharedLoadStreamD_,
  /// The functor computing an element-wise operation on the matrix product
  typename Functor_,
  /// Global memory mapping function
  typename GlobalDataLayout_ = MatrixLayout::ColumnMajor,
  /// The index.
  typename Index_ = int>
struct Volta884EpilogueTraits {
  /// The output tile.
  typedef OutputTile_ OutputTile;

  /// The warp-level GEMM tile
  typedef WarpGemmTile_ WarpGemmTile;

  /// Tiling of warp accumulator elements
  typedef WarpDelta_ WarpDelta;

  /// The accumulators fragment type.
  typedef Accumulators_ Accumulators;

  /// Selects a subset of accumulators for a given epilogue iteration
  typedef SelectAccumulators_ SelectAccumulators;

  /// The iterator to load source matrix from global memory.
  typedef GlobalLoadStreamC_ GlobalLoadStreamC;

  /// The iterator to store the final GEMM computation to global memory.
  typedef GlobalStoreStreamD_ GlobalStoreStreamD;

  /// The stream to store matrix product to shared memory
  typedef SharedStoreStreamD_ SharedStoreStreamD;

  /// The stream to load the matrix product from shared memory
  typedef SharedLoadStreamD_ SharedLoadStreamD;

  /// The functor computing an element-wise operation on the matrix product
  typedef Functor_ Functor;

  /// Global memory mapping function
  typedef GlobalDataLayout_ GlobalDataLayout;

  /// The index.
  typedef Index_ Index;

  /// The scalar type of the source accumulator matrix.
  typedef typename GlobalLoadStreamC::Iterator::Scalar ScalarC;

  /// The scalar type of the destination accumulator matrix.
  typedef typename GlobalStoreStreamD::Iterator::Scalar ScalarD;

  //
  // Dependent types
  //

  static bool const kFp32Arrangement = sizeof(typename SelectAccumulators::Scalar) == 4;

  /// Skew elements
  static int const kSkew = 2;

  /// Number of columns of accumulators stored/loaded depends on the accumulator arrangement
  static int const kColumnsPerWarp = (kFp32Arrangement ? 4 : 2);

  /// mma.sync instructions are arranged as spatially overlapping 32x32 tiles
  typedef Shape<4, 32, 32> MmaTileShape;

  /// Cover an entire warp-level tile
  typedef Shape<1,
                WarpGemmTile::kH / kColumnsPerWarp,  // iterates over 32x32 accumulator tiles along N dimension
                1,                                   // iterates over 32x32 accumulator tiles along M dimension
                1>
      Iterations;

  /// Skew is needed to reduce bank conflicts to SMEM - this shape depends on accumulator layout
  typedef Shape<1,
                WarpDelta::kH * kColumnsPerWarp,           // multiple columns in the gemm N dimension
                WarpDelta::kW * WarpGemmTile::kW + kSkew,  // rows in the gemm M dimension
                1
  > EpilogueTileAllocation;

  /// Parameters structure initialized on the host
  struct Params {
    /// The params for the C iterator.
    typename GlobalLoadStreamC::Params load_stream_c;

    /// The params for the D global iterator.
    typename GlobalStoreStreamD::Params store_stream_d;

    /// Epilogue functor params
    typename Functor::Params functor;

    /// The params for the D shared store iterator.
    typename SharedStoreStreamD::Params shared_store_stream_d;

    /// The params for the D shared load stream.
    typename SharedLoadStreamD::Params shared_load_stream_d;

    /// Batch stride for the C matrix
    long long int batch_stride_C;

    /// Batch stride for the D matrix
    long long int batch_stride_D;

    //
    // Methods
    //

    /// Default constructor
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Helper constructor taking pointer, stride for source and destination matrices and functor
    /// params
    CUTLASS_HOST_DEVICE
    Params(ScalarD *ptr_D,
           int ldd,
           ScalarC const *ptr_C,
           int ldc,
           typename Functor::Params _functor = Functor::Params())
        : load_stream_c(), store_stream_d(), functor(_functor) {}

    /// Setup the params.
    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
      batch_stride_C = desc.batch_stride_C;
      batch_stride_D = desc.batch_stride_D;

      // The parameters for the functor.
      int error_code = functor.initialize(desc);
      if (error_code) {
        return error_code;
      }

      // Setup the params for the global memory iterator for C.
      error_code = load_stream_c.iterator.initialize(
        desc.C.data(), desc.C.leading_dim(), desc.C.leading_dim(), 1
      );

      if (error_code) {
        return error_code;
      }

      // Setup the params for the global memory iterator for D.
      return store_stream_d.iterator.initialize(
        desc.D.data(), desc.D.leading_dim(), desc.D.leading_dim(), 1
      );
    }
  };

  /// Shared memory buffer used by epilogue
  typedef TileAllocation<
    typename SharedStoreStreamD::Iterator::Scalar,
    EpilogueTileAllocation> SharedStorage;

  /// Functor computing the offset from the threadblock origin per iteration of
  /// the epilogue.
  typedef Volta884EpilogueGlobalOffset<WarpDelta, typename SelectAccumulators::Scalar> GlobalOffset;
};
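
// Worked example (annotation only, not part of this file): with an FP32 accumulator
// arrangement (kFp32Arrangement == true), WarpGemmTile = Shape<4, 64, 64>, and
// WarpDelta = Shape<1, 2, 2>: kColumnsPerWarp = 4, so Iterations = Shape<1, 16, 1, 1>
// and EpilogueTileAllocation = Shape<1, 2 * 4, 2 * 64 + 2, 1> = Shape<1, 8, 130, 1>.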

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Volta884 Epilogue helper
//
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TileTraits, typename AccumulatorType>
struct Volta884EpiloguePredicateFunctor;

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Functor specialized for the predicate arrangement in the Volta884 epilogue
template <typename TileTraits>
struct Volta884EpiloguePredicateFunctor<TileTraits, float> {
  /// Dimensions of the bounding volume
  Coord<3> bounds;

  /// Constructs a predicate functor given the bounds of a tensor
  CUTLASS_HOST_DEVICE
  Volta884EpiloguePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {}

  /// Computes the predicate given the logical position of an access
  CUTLASS_HOST_DEVICE
  bool operator()(Coord<3> const &iteration, Coord<3> const &offset) const {
    return
      (iteration[0] * TileTraits::Delta::kD + iteration[1] * TileTraits::Delta::kH +
       offset[1] < bounds[1]) &&
      (iteration[2] * TileTraits::Delta::kW + offset[2] < bounds[2]);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Functor specialized for the predicate arrangement in the Volta884 epilogue
template <typename TileTraits>
struct Volta884EpiloguePredicateFunctor<TileTraits, half> {
  /// Dimensions of the bounding volume
  Coord<3> bounds;

  /// Constructs a predicate functor given the bounds of a tensor
  CUTLASS_HOST_DEVICE
  Volta884EpiloguePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {}

  /// Computes the predicate given the logical position of an access
  CUTLASS_HOST_DEVICE
  bool operator()(Coord<3> const &iteration, Coord<3> const &offset) const {
    return iteration[1] * TileTraits::Delta::kH + offset[1] < bounds[1] &&
           iteration[2] * TileTraits::Delta::kW + offset[2] < bounds[2];
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Volta884 Epilogue helper
//
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper to define the traits for a Volta884 Epilogue
template <
  typename GemmConfig_,
  typename EpilogueFunctor_,
  typename MultiplyAdd_ = typename GemmConfig_::MultiplyAdd,
  typename Index_ = int>
struct Volta884GemmEpilogueTraitsHelper {

  /// Configuration object defining GEMM properties
  typedef GemmConfig_ GemmConfig;

  /// Warp-level tile
  typedef typename GemmConfig::AccumulatorsPerWarp WarpGemmShape;

  /// Warp delta
  typedef typename ShapeDiv<
    typename GemmConfig::OutputTile,
    WarpGemmShape>::Shape WarpDelta;

  /// Thread-block scoped tile
  typedef typename cutlass::ShapeMul<
    WarpGemmShape,
    WarpDelta
  >::Shape OutputTile;

  /// Multiply-add operation
  typedef MultiplyAdd_ MultiplyAdd;

  /// Epilogue functor
  typedef EpilogueFunctor_ Functor;

  /// Traits for global tile access
  typedef cutlass::gemm::Volta884EpilogueGlobalTileTraits<
    WarpGemmShape,
    WarpDelta,
    1,
    typename MultiplyAdd::ScalarC
  > EpilogueGlobalTileTraits;

  /// Iterator to load a slice of the C matrix from global memory
  typedef cutlass::TileLoadIterator<
    EpilogueGlobalTileTraits,
    typename GemmConfig::ScalarC,
    cutlass::IteratorAdvance::kW,
    cutlass::MemorySpace::kGlobal
  > TileLoadIteratorC;

  /// Conversion from C data type to accumulator data type
  typedef Convert<
    typename TileLoadIteratorC::Fragment,
    Fragment<typename MultiplyAdd::ScalarC, TileLoadIteratorC::Fragment::kElements>
  > ConvertSourceFragment;

  /// Iterator to store a slice of the D matrix to global memory
  typedef cutlass::TileStoreIterator<
    EpilogueGlobalTileTraits,
    typename GemmConfig::ScalarD,
    cutlass::IteratorAdvance::kW,
    cutlass::MemorySpace::kGlobal
  > TileStoreIteratorD;

  /// Conversion from accumulator data type to D data type
  typedef Convert<
    Fragment<typename MultiplyAdd::ScalarC, TileStoreIteratorD::Fragment::kElements>,
    typename TileStoreIteratorD::Fragment
  > ConvertDestinationFragment;

  /// Defines traits for an epilogue of a Volta884 GEMM
  typedef cutlass::gemm::Volta884EpilogueTraits<
    OutputTile,
    WarpGemmShape,
    WarpDelta,
    typename MultiplyAdd::Accumulators,
    cutlass::gemm::Volta884SelectAccumulators<
      WarpGemmShape,
      WarpDelta,
      typename MultiplyAdd::ScalarC
    >,
    cutlass::PredicatedTileLoadStream<
      TileLoadIteratorC,
      cutlass::gemm::Volta884EpiloguePredicateFunctor<
        EpilogueGlobalTileTraits,
        typename MultiplyAdd::ScalarC>,
      ConvertSourceFragment
    >,
    cutlass::PredicatedTileStoreStream<
      TileStoreIteratorD,
      cutlass::gemm::Volta884EpiloguePredicateFunctor<
        EpilogueGlobalTileTraits,
        typename MultiplyAdd::ScalarC>,
      ConvertDestinationFragment
    >,
    cutlass::TileStoreStream<
      cutlass::gemm::Volta884EpilogueSharedStoreIterator<
        WarpGemmShape,
        WarpDelta,
        typename MultiplyAdd::ScalarC,
        typename MultiplyAdd::ScalarC
      >
    >,
    cutlass::TileLoadStream<
      cutlass::gemm::Volta884EpilogueSharedLoadIterator<
        WarpGemmShape,
        WarpDelta,
        typename MultiplyAdd::ScalarC,
        1,
        typename MultiplyAdd::ScalarC
      >
    >,
    Functor
  > EpilogueTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass

// clang-format on

@ -0,0 +1,585 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/

#pragma once

// clang-format off

#include "cutlass/gemm/clear_accumulators.h"
#include "cutlass/gemm/gemm_config.h"
#include "cutlass/gemm/gemm_global_stream.h"
#include "cutlass/gemm/gemm_stream_pair.h"
#include "cutlass/gemm/threadblock_swizzle.h"
#include "cutlass/gemm/linear_scaling.h"
#include "cutlass/kernel_launch.h"

#include "cutlass/gemm/gemm_desc.h"
#include "cutlass/gemm/volta884_multiplicand.h"
#include "cutlass/gemm/volta884_multiply_add.h"
#include "cutlass/gemm/mma_global_stream.h"
#include "cutlass/gemm/mma_shared_stream.h"
#include "cutlass/gemm/volta884_gemm_epilogue_traits.h"
#include "cutlass/gemm/mma_epilogue.h"
#include "cutlass/gemm/gemm_mainloop.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines configuration for Volta884 GEMM
template <
  /// The layout for A.
  MatrixLayout::Kind LayoutA,
  /// The layout for B.
  MatrixLayout::Kind LayoutB,
  /// The tile size for the GEMM KxNxM.
  typename OutputTile_,
  /// Tile size for warp-level GEMM (K-by-N-by-M)
  typename WarpGemmShape_,
  /// The accumulator type.
  typename Accumulator_,
  /// The source matrix type.
  typename ScalarC_,
  /// The destination matrix type
  typename ScalarD_,
  /// Number of stages in shared memory
  int StageCount,
  /// If true, kernel is launched with CUDA launch bounds specified
  bool kLaunchBounds = true,
  /// If true, residue is computed in mainloop. If false, separate loops are instantiated.
  bool kResidueSeparate = true,
  /// Is residue performed in prologue?
  bool kResidueInProlog = false>
struct Volta884GemmConfig : public GemmConfig<
                              /// The scalar type for A.
                              half,
                              /// The scalar type for B.
                              half,
                              /// The scalar type for C.
                              ScalarC_,
                              /// The scalar type for D.
                              ScalarD_,
                              /// The threadblock tile size
                              OutputTile_,
                              /// The functor to do the math in the main loop.
                              Volta884MultiplyAdd<WarpGemmShape_,
                                                  LayoutA,
                                                  half,
                                                  LayoutB,
                                                  half,
                                                  Accumulator_>,
                              /// The number of scalars per LDG for A.
                              8,
                              /// The number of scalars per STS for A.
                              8,
                              /// The number of scalars per LDS for A.
                              8,
                              /// The number of scalars per LDG for B.
                              8,
                              /// The number of scalars per STS for B.
                              8,
                              /// The number of scalars per LDS for B.
                              8,
                              /// The number of scalars per LDG for C and STG for D.
                              16 / int(sizeof(ScalarD_)),
                              /// The number of scalars per STS for D.
                              16 / int(sizeof(ScalarD_)),
                              /// The number of scalars per LDS for D.
                              16 / int(sizeof(ScalarD_)),
                              /// The number of stages in shared memory.
                              StageCount,
                              /// If true, separate mainloop is instantiated
                              kResidueSeparate,
                              /// If true, compute residue in prolog
                              kResidueInProlog,
                              /// If true, kernel is launched with CUDA launch bounds
                              kLaunchBounds> {};
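
// Illustrative sketch (not part of this file): a hypothetical instantiation of the
// Volta884GemmTraits defined below, pairing this config with a 32x128x128 (KxNxM)
// threadblock tile, 32x64x64 warp tiles, FP32 accumulation, and two shared memory
// stages. The shapes are illustrative, not prescriptive.
//
//   typedef cutlass::gemm::Volta884GemmTraits<
//       cutlass::MatrixLayout::kColumnMajor,  // layout of A
//       cutlass::MatrixLayout::kColumnMajor,  // layout of B
//       cutlass::Shape<32, 128, 128>,         // threadblock tile (KxNxM)
//       cutlass::Shape<32, 64, 64>,           // warp-level tile (KxNxM)
//       float,                                // accumulator type
//       float,                                // source (C) matrix type
//       float,                                // destination (D) matrix type
//       2                                     // stages in shared memory
//   > Volta884GemmTraitsExample;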

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines components of Volta884 GEMM
template <
    /// The layout for A.
    MatrixLayout::Kind LayoutA,
    /// The layout for B.
    MatrixLayout::Kind LayoutB,
    /// The tile size for the GEMM KxNxM.
    typename OutputTile_,
    /// Tile size for warp-level GEMM (K-by-N-by-M)
    typename WarpGemmShape_,
    /// The accumulator type.
    typename Accumulator_,
    /// The input matrix type.
    typename ScalarC_,
    /// The output matrix type.
    typename ScalarD_,
    /// Number of buffers in shared memory to use
    int StageCount,
    /// The functor to do the math in the epilogue.
    typename EpilogueFunctor_ = LinearScaling<Accumulator_>,
    /// The block swizzle to reorganize the grid.
    typename BlockSwizzle_ = IdentityBlockSwizzle,
    /// Selectively enables launch bounds
    bool LaunchBounds = false>
struct Volta884GemmTraits {
  /// This traits object
  typedef Volta884GemmTraits<LayoutA,
                             LayoutB,
                             OutputTile_,
                             WarpGemmShape_,
                             Accumulator_,
                             ScalarC_,
                             ScalarD_,
                             StageCount,
                             EpilogueFunctor_,
                             BlockSwizzle_,
                             LaunchBounds> This_;

  /// The struct that consumes this Traits
  typedef typename cutlass::gemm::GemmMainloop<This_> KernelClass;

  /// Layout of multiplicand A matrix
  static MatrixLayout::Kind const kLayoutA = LayoutA;

  /// Layout of multiplicand B matrix
  static MatrixLayout::Kind const kLayoutB = LayoutB;

  /// Dimensions of threadblock tile (concept Shape)
  typedef OutputTile_ OutputTile;

  /// Shape of warp-level accumulators
  typedef WarpGemmShape_ WarpGemmShape;

  /// Multiplicand A scalar type
  typedef half ScalarA;

  /// Multiplicand B scalar type
  typedef half ScalarB;

  /// Data type of internal accumulator
  typedef Accumulator_ Accumulator;

  /// Data type of input accumulator matrix operand
  typedef ScalarC_ ScalarC;

  /// Data type of output accumulator matrix operand
  typedef ScalarD_ ScalarD;

  /// Shape of individual mma.sync instruction
  typedef Shape<4, 16, 16> InstructionShape;

  /// Tile size for an individual warp-level multiply-add
  typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;

  /// Defines properties about GEMM needed by host code
  typedef Volta884GemmConfig<kLayoutA,
                             kLayoutB,
                             OutputTile,
                             WarpGemmShape,
                             Accumulator,
                             ScalarC,
                             ScalarD,
                             StageCount,
                             LaunchBounds>
      GemmConfig;
  //
  // Derived types
  //

  /// Index type
  typedef int Index;

  /// Partitioning of threadblock into warps
  typedef typename ShapeDiv<OutputTile, WarpGemmShape>::Shape WarpDelta;

  /// Number of warps per threadblock
  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
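
  // Worked example (illustrative comment, not in the original source):
  // assuming a threadblock OutputTile of Shape<32, 128, 128> and a
  // WarpGemmShape of Shape<32, 64, 64>, ShapeDiv yields
  // WarpDelta = Shape<1, 2, 2>, so kWarpCount = 1 * 2 * 2 = 4 warps
  // (128 threads) per threadblock.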

  /// Defines iterators for A matrix
  typedef Volta884Multiplicand<GemmOperand::kA, kLayoutA, OutputTile, WarpTile, kWarpCount, WarpDelta>
      MultiplicandA;

  /// Defines iterators for B matrix
  typedef Volta884Multiplicand<GemmOperand::kB, kLayoutB, OutputTile, WarpTile, kWarpCount, WarpDelta>
      MultiplicandB;

  //
  // GemmTraits mandatory type definitions
  //

  /// Maps hardware threadblocks to logical partitions of the GEMM
  typedef BlockSwizzle_ BlockSwizzle;

  /// Clears accumulators
  typedef ClearAccumulators<ScalarC> ClearAccumulators;

  /// Loads multiplicands from global memory
  typedef GlobalLoadStreamPair<
      MMAGlobalLoadStream<GemmOperand::kA,
                          kLayoutA,
                          typename MultiplicandA::LoadIterator,
                          Copy<typename MultiplicandA::LoadIterator::Fragment>,
                          typename MultiplicandA::StoreIterator,
                          StageCount>,
      MMAGlobalLoadStream<GemmOperand::kB,
                          kLayoutB,
                          typename MultiplicandB::LoadIterator,
                          Copy<typename MultiplicandB::LoadIterator::Fragment>,
                          typename MultiplicandB::StoreIterator,
                          StageCount>,
      GemmConfig::kResidueInProlog>
      GlobalLoadStream;

  /// Memory needed to store the threadblock-scoped GEMM tile
  typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage;

  union MainLoopStorage {
    /// Stores the threadblock tile
    ThreadblockTileStorage threadblock_tile;

    /// Storage for GEMM global stream
    typename GlobalLoadStream::SharedStorage global_to_shared_stream;
  };

  /// Loads multiplicands from shared memory
  typedef SharedStreamPair<
      MMASharedLoadStream<typename MultiplicandA::WarpLoadIterator,
                          Copy<typename MultiplicandA::WarpLoadIterator::Fragment>,
                          StageCount>,
      MMASharedLoadStream<typename MultiplicandB::WarpLoadIterator,
                          Copy<typename MultiplicandB::WarpLoadIterator::Fragment>,
                          StageCount> >
      SharedStream;

  // Multiply-add object specialized for Volta mma.sync
  typedef typename GemmConfig::MultiplyAdd MultiplyAdd;

#if 0
  /// Naive epilogue for updating the output matrix
  typedef cutlass::gemm::Volta884NaiveEpilogue<ScalarC,
                                               typename MultiplicandA::WarpDelta,
                                               typename MultiplyAdd::Iterations>
      Epilogue;
#else
  /// Efficient epilogue
  typedef cutlass::gemm::MMAEpilogue<
      typename Volta884GemmEpilogueTraitsHelper<GemmConfig, EpilogueFunctor_>::EpilogueTraits>
      Epilogue;
#endif
  /// Parameters structure
  struct Params : public KernelLaunchConfiguration {
    /// The dimensions of the GEMM.
    GemmCoord problem_size;

    /// The K range for every partition except the last one
    int partitionK_range;

    /// The params for the global load stream
    typename GlobalLoadStream::Params global_to_shared_stream;

    /// The params for the shared load stream
    typename SharedStream::Params shared_stream;

    /// The params for the epilogue.
    typename Epilogue::Params epilogue;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() {}

    /// Initialize the parameters.
    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE Params(GemmDesc_ const& desc) {
      initialize(desc);
    }

    /// Initialize the Params struct
    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
      // Problem size
      problem_size = desc.problem_size;

      // There is no partitionK in the default case
      partitionK_range = problem_size[0];

      // Compute grid dimensions
      BlockSwizzle block_swizzle;
      this->block = dim3(GemmConfig::kThreads);
      this->grid = block_swizzle.get_grid_layout(
          problem_size,
          make_Coord_from_shape<OutputTile>());

      // Compute offset to residue
      Index gemm_k = problem_size[0];
      Index offset_to_residue = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
      Index offset_to_residue_last_partition =
          (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;

      // Initialize parameters objects for the global-to-shared streams
      global_to_shared_stream.stream_a.initialize(
          desc.A,
          desc.batch_stride_A,
          offset_to_residue,
          offset_to_residue_last_partition);

      global_to_shared_stream.stream_b.initialize(
          desc.B,
          desc.batch_stride_B,
          offset_to_residue,
          offset_to_residue_last_partition);

      // The epilogue.
      epilogue.initialize(desc);
      return 0;
    }

    /// Helper to construct a GEMM params using a BLAS-like API
    CUTLASS_HOST_DEVICE int initialize(Index m,
                                       Index n,
                                       Index k,
                                       typename Epilogue::Scalar alpha,
                                       ScalarA const* d_a,
                                       Index lda,
                                       ScalarB const* d_b,
                                       Index ldb,
                                       typename Epilogue::Scalar beta,
                                       ScalarC const* d_c,
                                       Index ldc,
                                       ScalarD* d_d,
                                       Index ldd) {
      GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
          GemmCoord(k, n, m, 1),
          alpha,
          TensorRef<ScalarA const, 2>(d_a, lda),
          TensorRef<ScalarB const, 2>(d_b, ldb),
          beta,
          TensorRef<ScalarC const, 2>(d_c, ldc),
          TensorRef<ScalarD, 2>(d_d, ldd));

      return this->initialize(desc);
    }
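
    // Usage sketch (illustrative, not part of the original source); the
    // pointers and leading dimensions are assumed to describe device-side
    // matrices laid out according to the traits' layout parameters:
    //
    //   typename Traits::Params params;
    //   int result = params.initialize(
    //       m, n, k, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd);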

    /// Helper to construct a batched GEMM params
    CUTLASS_HOST_DEVICE int initialize(Index m,
                                       Index n,
                                       Index k,
                                       typename Epilogue::Scalar alpha,
                                       ScalarA const* d_a,
                                       Index lda,
                                       long long int batch_stride_A,
                                       ScalarB const* d_b,
                                       Index ldb,
                                       long long int batch_stride_B,
                                       typename Epilogue::Scalar beta,
                                       ScalarC const* d_c,
                                       Index ldc,
                                       long long int batch_stride_C,
                                       ScalarD* d_d,
                                       Index ldd,
                                       long long int batch_stride_D,
                                       Index batch_count) {
      GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
          make_Coord(k, n, m, batch_count),
          alpha,
          TensorRef<ScalarA const, 2>(d_a, lda),
          batch_stride_A,
          TensorRef<ScalarB const, 2>(d_b, ldb),
          batch_stride_B,
          beta,
          TensorRef<ScalarC const, 2>(d_c, ldc),
          batch_stride_C,
          TensorRef<ScalarD, 2>(d_d, ldd),
          batch_stride_D);

      return this->initialize(desc);
    }

    /// Helper to construct a partitionedK GEMM params
    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitionK_desc,
                                       Index partitionK_count_,
                                       Index partitionK_multiple_ = 1) {
      // partitionK GEMM is a specialized batched strided GEMM with different K ranges per batch:
      // every batch except the last covers partitionK_range elements of K, the last batch covers
      // the remainder, and each batch's problem_size is (lastK_range, n, m).

      // The K range for every batch except the last one
      partitionK_range = partitionK_desc.problem_size.k() / partitionK_count_;
      partitionK_range = partitionK_range - (partitionK_range % partitionK_multiple_);

      // The K range of the last batch
      int lastK_range = partitionK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1);

      assert((partitionK_range % partitionK_multiple_) == 0);
      assert(partitionK_range > 0);
      assert((lastK_range % partitionK_multiple_) == 0);
      assert(lastK_range > 0);
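
      // Worked example (illustrative comment, not in the original source):
      // for K = 1000, partitionK_count_ = 3, and partitionK_multiple_ = 8,
      // the division gives 333, rounded down to partitionK_range = 328; the
      // last partition then covers lastK_range = 1000 - 328 * 2 = 344. Both
      // are multiples of 8, so the asserts above hold.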

      int k_size = lastK_range;
      int lda = partitionK_desc.A.stride(0);
      int ldb = partitionK_desc.B.stride(0);
      int ldc = partitionK_desc.C.stride(0);
      int ldd = partitionK_desc.D.stride(0);
      int n = partitionK_desc.problem_size.n();

      long long int batch_stride_A =
          (kLayoutA == cutlass::MatrixLayout::kColumnMajor) ? lda * partitionK_range : partitionK_range;
      long long int batch_stride_B =
          (kLayoutB == cutlass::MatrixLayout::kColumnMajor) ? partitionK_range : partitionK_range * ldb;
      long long int batch_stride_C = ldc * n;
      long long int batch_stride_D = ldd * n;

      GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
          // We pass lastK_range as the per-batch K; every other batch covers partitionK_range.
          GemmCoord(k_size, partitionK_desc.problem_size.n(), partitionK_desc.problem_size.m(), partitionK_count_),
          partitionK_desc.alpha,
          partitionK_desc.A,
          batch_stride_A,
          partitionK_desc.B,
          batch_stride_B,
          partitionK_desc.beta,
          partitionK_desc.C,
          batch_stride_C,
          partitionK_desc.D,
          batch_stride_D);

      // Set the problem size.
      problem_size = desc.problem_size;

      // Compute grid dimensions
      BlockSwizzle block_swizzle;
      this->block = dim3(GemmConfig::kThreads);
      this->grid = block_swizzle.get_grid_layout(
          problem_size,
          make_Coord_from_shape<OutputTile>());

      // Compute offset to residue.
      // partitionK_range <= problem_size[0]
      Index gemm_k = problem_size[0];
      Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
      Index offset_to_residue =
          (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;

      // Initialize parameters objects for the global-to-shared streams
      int error_code = global_to_shared_stream.stream_a.initialize(
          desc.A,
          desc.batch_stride_A,
          offset_to_residue,
          offset_to_residue_last_partition);
      if (error_code) {
        return error_code;
      }

      error_code = global_to_shared_stream.stream_b.initialize(
          desc.B,
          desc.batch_stride_B,
          offset_to_residue,
          offset_to_residue_last_partition);
      if (error_code) {
        return error_code;
      }

      // The epilogue.
      return epilogue.initialize(desc);
    }

    /// Helper to construct a partitionedK GEMM params
    CUTLASS_HOST_DEVICE int initialize(Index m,
                                       Index n,
                                       Index k,
                                       typename Epilogue::Scalar alpha,
                                       ScalarA const* d_a,
                                       Index lda,
                                       ScalarB const* d_b,
                                       Index ldb,
                                       typename Epilogue::Scalar beta,
                                       ScalarC const* d_c,
                                       Index ldc,
                                       ScalarD* d_d,
                                       Index ldd,
                                       Index partitionK_count_,
                                       Index partitionK_multiple_ = 1) {
      GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
          GemmCoord(k, n, m, 1),
          alpha,
          TensorRef<ScalarA const, 2>(d_a, lda),
          TensorRef<ScalarB const, 2>(d_b, ldb),
          beta,
          TensorRef<ScalarC const, 2>(d_c, ldc),
          TensorRef<ScalarD, 2>(d_d, ldd));

      return this->initialize(desc, partitionK_count_, partitionK_multiple_);
    }
  };

  /// Shared memory storage
  union SharedStorage {
    /// Storage required during mainloop phase
    MainLoopStorage main_loop;

    /// Shared storage needed for epilogue
    typename Epilogue::SharedStorage epilogue;
  };

  /// The memory fence for shared loads. With two or more stages, the pipeline is
  /// double-buffered, so no barrier is needed between loads.
  static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
    if (StageCount < 2) {
      __syncthreads();
    }
  }

  /// The memory fence for shared stores.
  static CUTLASS_DEVICE void shared_store_fence(bool in_loop) {
    __syncthreads();
  }
};
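
// Usage sketch (illustrative, not part of the original source): an FP16 GEMM
// with FP32 accumulation might instantiate the traits as follows; the tile
// shapes and layouts are assumptions chosen for this example.
//
//   typedef cutlass::gemm::Volta884GemmTraits<
//       cutlass::MatrixLayout::kColumnMajor,  // layout of A
//       cutlass::MatrixLayout::kColumnMajor,  // layout of B
//       cutlass::Shape<32, 128, 128>,         // threadblock tile (KxNxM)
//       cutlass::Shape<32, 64, 64>,           // warp-level tile (KxNxM)
//       float,                                // accumulator type
//       float,                                // source matrix (C) type
//       float,                                // destination matrix (D) type
//       2>                                    // stages in shared memory
//       GemmTraits;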

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass

// clang-format on
@@ -0,0 +1,298 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/

#pragma once

#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
#include "cutlass/util/platform.h"

#include "cutlass/gemm/mma_global_tile.h"
#include "cutlass/gemm/volta884_shared_tile.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines iterators for loading and storing multiplicands
template <
    /// Identifies multiplicand of GEMM (A or B)
    GemmOperand::Kind Operand,
    /// Specifies layout of data in source memory
    MatrixLayout::Kind Layout,
    /// Specifies threadblock tile shape
    typename Tile,
    /// Specifies warp tile shape
    typename WarpTile,
    /// Specifies the number of participating warps
    int WarpCount,
    /// Specifies the delta between warp tiles
    typename WarpDelta_>
struct Volta884Multiplicand;

////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines iterators for loading and storing multiplicands for A.column_major
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
struct Volta884Multiplicand<GemmOperand::kA,
                            MatrixLayout::kColumnMajor,
                            Tile_,
                            WarpTile_,
                            WarpCount,
                            WarpDelta_> {
  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = GemmOperand::kA;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;

  /// Thread-block tile shape
  typedef Tile_ Tile;

  /// Warp-level matrix multiply-add shape
  typedef WarpTile_ WarpTile;

  /// Total number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp tiles
  typedef WarpDelta_ WarpDelta;

  //
  // Thread-block load iterator
  //
  typedef
      typename MMAThreadblockCongruousLoad<kOperand, Tile_, WarpCount, WarpDelta::kW>::Iterator
          LoadIterator;

  //
  // Thread-block store iterator
  //
  typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
                                                       kLayout,
                                                       Tile_,
                                                       WarpCount,
                                                       WarpDelta::kW>
      StoreIterator;

  //
  // Warp-level load iterator
  //
  typedef Volta884WarpMultiplicandLoadIterator<kOperand,
                                               kLayout,
                                               Tile_,
                                               WarpTile_,
                                               WarpCount,
                                               WarpDelta>
      WarpLoadIterator;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines iterators for loading and storing multiplicands for B.row_major
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
struct Volta884Multiplicand<GemmOperand::kB,
                            MatrixLayout::kRowMajor,
                            Tile_,
                            WarpTile_,
                            WarpCount,
                            WarpDelta_> {
  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = GemmOperand::kB;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;

  /// Thread-block tile shape
  typedef Tile_ Tile;

  /// Warp-level matrix multiply-add shape
  typedef WarpTile_ WarpTile;

  /// Total number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp tiles
  typedef WarpDelta_ WarpDelta;

  //
  // Thread-block load iterator
  //
  typedef
      typename MMAThreadblockCongruousLoad<kOperand, Tile_, WarpCount, WarpDelta::kH>::Iterator
          LoadIterator;

  //
  // Thread-block store iterator
  //
  typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
                                                       kLayout,
                                                       Tile_,
                                                       WarpCount,
                                                       WarpDelta::kH>
      StoreIterator;

  //
  // Warp-level load iterator
  //
  typedef Volta884WarpMultiplicandLoadIterator<kOperand,
                                               kLayout,
                                               Tile_,
                                               WarpTile_,
                                               WarpCount,
                                               WarpDelta>
      WarpLoadIterator;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines iterators for loading and storing multiplicands for A.row_major
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
struct Volta884Multiplicand<GemmOperand::kA,
                            MatrixLayout::kRowMajor,
                            Tile_,
                            WarpTile_,
                            WarpCount,
                            WarpDelta_> {
  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = GemmOperand::kA;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;

  /// Thread-block tile shape
  typedef Tile_ Tile;

  /// Warp-level matrix multiply-add shape
  typedef WarpTile_ WarpTile;

  /// Total number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp tiles
  typedef WarpDelta_ WarpDelta;

  //
  // Thread-block load iterator
  //
  typedef
      typename MMAThreadblockCrosswiseLoad<kOperand, Tile_, WarpCount, WarpDelta::kW>::Iterator
          LoadIterator;

  //
  // Thread-block store iterator
  //
  typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
                                                       kLayout,
                                                       Tile_,
                                                       WarpCount,
                                                       WarpDelta::kW>
      StoreIterator;

  //
  // Warp-level load iterator
  //
  typedef Volta884WarpMultiplicandLoadIterator<kOperand,
                                               kLayout,
                                               Tile_,
                                               WarpTile_,
                                               WarpCount,
                                               WarpDelta>
      WarpLoadIterator;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines iterators for loading and storing multiplicands for B.column_major
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
struct Volta884Multiplicand<GemmOperand::kB,
                            MatrixLayout::kColumnMajor,
                            Tile_,
                            WarpTile_,
                            WarpCount,
                            WarpDelta_> {
  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = GemmOperand::kB;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;

  /// Thread-block tile shape
  typedef Tile_ Tile;

  /// Warp-level matrix multiply-add shape
  typedef WarpTile_ WarpTile;

  /// Total number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp tiles
  typedef WarpDelta_ WarpDelta;

  //
  // Thread-block load iterator
  //
  typedef
      typename MMAThreadblockCrosswiseLoad<kOperand, Tile_, WarpCount, WarpDelta::kH>::Iterator
          LoadIterator;

  //
  // Thread-block store iterator
  //
  typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
                                                       kLayout,
                                                       Tile_,
                                                       WarpCount,
                                                       WarpDelta::kH>
      StoreIterator;

  //
  // Warp-level load iterator
  //
  typedef Volta884WarpMultiplicandLoadIterator<kOperand,
                                               kLayout,
                                               Tile_,
                                               WarpTile_,
                                               WarpCount,
                                               WarpDelta>
      WarpLoadIterator;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,704 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements warp-level multiply-accumulate operations using Volta's mma.sync instruction
*/

#pragma once

#include "cutlass/arch/mma.h"
#include "cutlass/fragment.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Shape of a warp-level GEMM (K-by-N-by-M)
    typename WarpGemmShape_,
    /// Layout of A multiplicand
    MatrixLayout::Kind LayoutA,
    /// Data type of A multiplicand
    typename ScalarA_,
    /// Layout of B multiplicand
    MatrixLayout::Kind LayoutB,
    /// Data type of B multiplicand
    typename ScalarB_,
    /// Data type of accumulators
    typename ScalarC_>
struct Volta884MultiplyAdd {
  //
  // Constant and type definitions
  //

  /// Shape of a warp-level GEMM (K-by-N-by-M)
  typedef WarpGemmShape_ WarpGemmShape;

  /// Shape of a warp-level GEMM (K-by-N-by-M)
  typedef WarpGemmShape_ AccumulatorsPerWarp;

  /// Most of the Volta884 code assumes interleaved 32x32 tiles
  typedef Shape<4, 32, 32> InterleavedTileShape;

  /// Shape of an individual warp-wide Volta mma.sync instruction
  typedef Shape<4, 16, 16> InstructionShape;

  /// Shape of a warp-level matrix multiply operation
  typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;

  /// Verify WarpTile is a multiple of fundamental 32x32 interleaved tile
  static_assert(!(WarpTile::kH % InterleavedTileShape::kH) &&
                    !(WarpTile::kW % InterleavedTileShape::kW) && WarpTile::kD == 4,
                "WarpTile must be a multiple of InterleavedTileShape.");

  /// Layout of A multiplicand
  static MatrixLayout::Kind const kLayoutA = LayoutA;

  /// Layout of B multiplicand
  static MatrixLayout::Kind const kLayoutB = LayoutB;

  /// The type for A.
  typedef ScalarA_ ScalarA;

  /// The type for B.
  typedef ScalarB_ ScalarB;

  /// The type for C and D.
  typedef ScalarC_ ScalarC;

  /// Hard-coded compute type supported on Volta
  static arch::ComputeType::Kind const kComputeType = arch::ComputeType::kDefault;

  /// Defines a warp-level matrix multiply-accumulate operation.
  //
  // The layout is as follows. The entire warp performs a 64x64x4 GEMM using Volta mma.sync macros
  // arranged as a 2x2 tile of adjacent, 32x32x4 matrix products. These are implemented as a
  // 2x2 arrangement of spatially interleaved Volta mma.sync macros.
  //
  // The Iterations shape maps to the following dimensions of the above warp-level GEMM:
  //
  //   kC: number of rows of Volta mma.sync macros in a 32x32x4 tile
  //   kW: number of columns of Volta mma.sync macros in a 32x32x4 tile
  //   kH: number of rows of 32x32x4 macros in the larger 64x64x4 tile
  //   kD: number of columns of 32x32x4 macros in the larger 64x64x4 tile
  //
  // A column-major ordering would arrange C and H as the inner-most loops, with W and D as the
  // outer-most.
  //
  typedef Shape<WarpTile::kH / InterleavedTileShape::kH,
                WarpTile::kW / InterleavedTileShape::kW,
                InterleavedTileShape::kH / InstructionShape::kH,
                InterleavedTileShape::kW / InstructionShape::kW>
      Iterations;
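
  // Worked example (illustrative comment, not in the original source): for a
  // 64x64x4 warp tile, Iterations = Shape<64/32, 64/32, 32/16, 32/16>
  //                               = Shape<2, 2, 2, 2>,
  // i.e. a 2x2 grid of interleaved 32x32 tiles, each computed by a 2x2
  // arrangement of 16x16 mma.sync operations: 16 in total per K-step of 4.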

  /// Number of multiplicand elements per instruction
  static int const kMultElementsPerInst = 4;

  /// Number of accumulator elements per instruction
  static int const kAccumElementsPerInst = 8;

  /// Fragment definition for A multiplicand
  typedef Fragment<ScalarA, Iterations::kH * Iterations::kC * kMultElementsPerInst> FragmentA;

  /// Fragment definition for B multiplicand
  typedef Fragment<ScalarB, Iterations::kW * Iterations::kD * kMultElementsPerInst> FragmentB;

  /// Fragment definition for accumulators
  typedef Fragment<ScalarC, ShapeCount<Iterations>::kCount * kAccumElementsPerInst> Accumulators;
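
  // Illustrative sizing (added comment, not in the original source):
  // continuing the 64x64x4 example above, each thread holds
  // FragmentA = 2 * 2 * 4 = 16 elements, FragmentB = 2 * 2 * 4 = 16 elements,
  // and Accumulators = 16 * 8 = 128 elements, matching the warp's 64x64
  // accumulator tile distributed across its 32 threads.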

  //
  // Methods
  //

  /// Ctor.
  CUTLASS_DEVICE Volta884MultiplyAdd() {}

  /// Multiply: d = (-)a*b + c.
  CUTLASS_DEVICE void multiply_add(FragmentA const& A,
                                   FragmentB const& B,
                                   Accumulators const& C,
                                   Accumulators& D,
                                   bool negate = false) {
// Guard conditional needed for __hneg2
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)

    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {  // Outer column
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < Iterations::kW; ++w) {  // Inner column
        CUTLASS_PRAGMA_UNROLL
        for (int h_raw = 0; h_raw < Iterations::kH; ++h_raw) {  // Outer row
          CUTLASS_PRAGMA_UNROLL
          for (int c_raw = 0; c_raw < Iterations::kC; ++c_raw) {  // Inner row

            int op_col = (w + Iterations::kW * d);

            // Column-major serpentine sequence to maximize reuse of the B operand.
            int h = h_raw;
            int c = c_raw;

            if (op_col & 1) {
              h = Iterations::kH - h_raw - 1;
              c = Iterations::kC - c_raw - 1;
            }
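
            // Illustrative note (not in the original source): with
            // Iterations::kH == Iterations::kC == 2, even operand-B columns
            // visit (h, c) in the order (0,0), (0,1), (1,0), (1,1) and odd
            // columns reverse it, so the last A fragment used in one column
            // is the first used in the next, improving register reuse.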

            int op_row = (c + Iterations::kC * h);
            int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));

            // Copy the A fragment so it may be negated without disturbing the source.
            ScalarA operand_A[kMultElementsPerInst];

            reinterpret_cast<uint64_t&>(operand_A[0]) =
                reinterpret_cast<uint64_t const&>(A[op_row * kMultElementsPerInst]);

            if (negate) {
              CUTLASS_PRAGMA_UNROLL
              for (int i = 0; i < kMultElementsPerInst; i += 2) {
                reinterpret_cast<half2&>(operand_A[i]) =
                    __hneg2(reinterpret_cast<half2 const&>(A[op_row * kMultElementsPerInst + i]));
              }
            }

            // Issue a Volta mma.sync instruction
            arch::mma<InstructionShape,
                      kLayoutA,
                      ScalarA,
                      kLayoutB,
                      ScalarB,
                      ScalarC,
                      kComputeType>(
                operand_A,
                &B[op_col * kMultElementsPerInst],
                &C[op_idx * kAccumElementsPerInst],
                &D[op_idx * kAccumElementsPerInst]);
          }
        }
      }
    }
#endif  // (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Accumulator, typename WarpDelta, typename Iterations>
struct Volta884NaiveEpilogue;

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Naive epilogue specialized for f32 accumulators - may be considered an authoritative mapping of
/// accumulators to mma.sync operations.
template <typename WarpDelta_, typename Iterations_>
struct Volta884NaiveEpilogue<float, WarpDelta_, Iterations_> {
  /// Accumulator data type
  typedef float ScalarC;

  /// Output accumulator type
  typedef float ScalarD;

  /// BLAS Scalar type
  typedef float Scalar;

  /// Delta among warp tiles
  typedef WarpDelta_ WarpDelta;

  /// Number of Volta mma.sync operations
  typedef Iterations_ Iterations;

  /// Most of the Volta884 code assumes interleaved 32x32 tiles
  typedef Shape<4, 32, 32> InterleavedTileShape;

  /// Number of accumulator elements per instruction
  static int const kAccumElementsPerInst = 8;

  /// Fragment definition for accumulators
  typedef Fragment<ScalarC, ShapeCount<Iterations>::kCount * kAccumElementsPerInst> Accumulators;

  /// Params object
  struct Params {
    /// Pointer to output matrix
    ScalarC* ptr;

    /// Leading dimension (stride) of the output matrix
    int ldm;

    /// Scalar alpha
    float alpha;

    /// Scalar beta
    float beta;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() : ptr(0), ldm(0), alpha(1), beta(0) {}

    CUTLASS_HOST_DEVICE
    Params(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0)
        : ptr(_ptr), ldm(_ldm), alpha(_alpha), beta(_beta) {}

    /// Initialize method
    CUTLASS_HOST_DEVICE
    int initialize(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0) {
      ptr = _ptr;
      ldm = _ldm;
      alpha = _alpha;
      beta = _beta;
      return 0;
    }

    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
      ptr = reinterpret_cast<ScalarC*>(desc.D.data());
      ldm = desc.D.leading_dim();
      alpha = desc.alpha;
      beta = desc.beta;
      return 0;
    }
  };

  /// Shared storage
  struct SharedStorage {};

  /// Helper used to compute initial offset for each thread
  struct InitialOffset {
    int row_offset;
    int col_offset;

    /// Constructor
    CUTLASS_DEVICE
    InitialOffset() {
      int warp_id = (threadIdx.x >> 5);
      int lane_id = (threadIdx.x & 0x1f);
      int quad_id = (lane_id >> 2);
      int quadpair_id = (quad_id & 0x3);

      int quadpair_row = (quadpair_id & 1);
      int quadpair_col = (quadpair_id >> 1);
      int quad_hilo = (quad_id >> 2) & 1;

      // Compute initial offset
      int warp_row_offset = (warp_id % WarpDelta::kW) * InterleavedTileShape::kW;
      int warp_col_offset = (warp_id / WarpDelta::kW) * InterleavedTileShape::kH;

      int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 1);
      int thread_col_offset = (quadpair_col * 2) * 8 + (lane_id & 2);

      row_offset = warp_row_offset + thread_row_offset;
      col_offset = warp_col_offset + thread_col_offset;
    }
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  /// Problem size
  Coord<3> problem_size;

  //
  // Methods
  //

  /// Constructs the epilogue from a Params object
  CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params,
                                       Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
      : params(_params), problem_size(_problem_size) {}

  /// Constructs the epilogue from a pointer and leading dimension
  CUTLASS_DEVICE Volta884NaiveEpilogue(ScalarC* _ptr,
                                       int _ldm,
                                       Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
      : params(_ptr, _ldm), problem_size(_problem_size) {}

  /// Constructs the epilogue from a Params object and shared storage
  CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params,
                                       SharedStorage& shared_storage,
                                       Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
      : params(_params), problem_size(_problem_size) {}

  /// Sets accumulators to zero
  CUTLASS_DEVICE void clear(Accumulators& C) {
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          CUTLASS_PRAGMA_UNROLL
          for (int c = 0; c < Iterations::kC; ++c) {
            int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));

            CUTLASS_PRAGMA_UNROLL
            for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
              C[op_idx * kAccumElementsPerInst + reg] = 0;
            }
          }
        }
      }
    }
  }

  /// Naive load operation for debugging
  CUTLASS_DEVICE void load(Accumulators& C,
                           Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
    InitialOffset initial;

    initial.row_offset += threadblock_offset[2];
    initial.col_offset += threadblock_offset[1];

    ScalarC const* load_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;

    // Load accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          CUTLASS_PRAGMA_UNROLL
          for (int c = 0; c < Iterations::kC; ++c) {
            ScalarC const* op_ptr = load_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
                                    d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;

            int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));

            CUTLASS_PRAGMA_UNROLL
            for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
              int tr = (reg & 2) + c * 4;
              int tc = (reg & 1) + (reg & 4) * 2 + w * 4;

              int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
              int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;

              if (row < problem_size[2] && column < problem_size[1]) {
                C[op_idx * kAccumElementsPerInst + reg] = op_ptr[tr + tc * params.ldm];
              }
            }
          }
        }
      }
    }
  }

  /// Naive store operation for debugging
  CUTLASS_DEVICE void store(Accumulators const& C,
                            Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
    InitialOffset initial;

    initial.row_offset += threadblock_offset[2];
    initial.col_offset += threadblock_offset[1];

    ScalarC* store_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;

    // Store out accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          CUTLASS_PRAGMA_UNROLL
          for (int c = 0; c < Iterations::kC; ++c) {
            ScalarC* op_ptr = store_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
                              d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;

            int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));

            CUTLASS_PRAGMA_UNROLL
            for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
              int tr = (reg & 2) + c * 4;
              int tc = (reg & 1) + (reg & 4) * 2 + w * 4;

              int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
              int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;

              if (row < problem_size[2] && column < problem_size[1]) {
                op_ptr[tr + tc * params.ldm] =
                    params.alpha * C[op_idx * kAccumElementsPerInst + reg] +
                    params.beta * op_ptr[tr + tc * params.ldm];
              }
            }
          }
        }
      }
    }
  }

  /// CUTLASS Epilogue interface
  CUTLASS_DEVICE void epilogue(Accumulators const& C,
                               Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
    store(C, threadblock_offset);
  }

  CUTLASS_DEVICE void epilogue(Accumulators& C,
                               Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
    store(C, threadblock_offset);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Naive epilogue specialized for f16 accumulators - may be considered an authoritative mapping of
/// accumulators to mma.sync operations.
template <typename WarpDelta_, typename Iterations_>
struct Volta884NaiveEpilogue<half, WarpDelta_, Iterations_> {
  /// Accumulator data type
  typedef half ScalarC;

  /// Output accumulator type
  typedef half ScalarD;

  /// BLAS Scalar type
  typedef half Scalar;

  /// Delta among warp tiles
  typedef WarpDelta_ WarpDelta;

  /// Number of Volta mma.sync operations
  typedef Iterations_ Iterations;

  /// Most of the Volta884 code assumes interleaved 32x32 tiles
  typedef Shape<4, 32, 32> InterleavedTileShape;

  /// Number of accumulator elements per instruction
  static int const kAccumElementsPerInst = 8;

  /// Fragment definition for accumulators
  typedef Fragment<ScalarC, ShapeCount<Iterations>::kCount * kAccumElementsPerInst> Accumulators;

  /// Params object
  struct Params {
    /// Pointer to output matrix
    ScalarC* ptr;

    /// Leading dimension (stride) of the output matrix
    int ldm;

    /// Scalar alpha
    half alpha;

    /// Scalar beta
    half beta;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() : ptr(0), ldm(0), alpha(1), beta(0) {}

    CUTLASS_HOST_DEVICE
    Params(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0)
        : ptr(_ptr), ldm(_ldm), alpha(_alpha), beta(_beta) {}

    /// Initialize method
    CUTLASS_HOST_DEVICE
    int initialize(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0) {
      ptr = _ptr;
      ldm = _ldm;
      alpha = _alpha;
      beta = _beta;
      return 0;
    }

    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
      ptr = reinterpret_cast<ScalarC*>(desc.D.data());
      ldm = desc.D.leading_dim();
      alpha = desc.alpha;
      beta = desc.beta;
      return 0;
    }
  };

  /// Shared storage
  struct SharedStorage {};

  /// Helper used to compute initial offset for each thread
  struct InitialOffset {
    int row_offset;
    int col_offset;

    /// Constructor
    CUTLASS_DEVICE
    InitialOffset() {
      int warp_id = (threadIdx.x >> 5);
      int lane_id = (threadIdx.x & 0x1f);
      int quad_id = (lane_id >> 2);
      int quadpair_id = (quad_id & 0x3);

      int quadpair_row = (quadpair_id & 1);
      int quadpair_col = (quadpair_id >> 1);
      int quad_hilo = (quad_id >> 2) & 1;

      // Compute initial offset
      int warp_row_offset = (warp_id % WarpDelta::kW) * InterleavedTileShape::kW;
      int warp_col_offset = (warp_id / WarpDelta::kW) * InterleavedTileShape::kH;

      int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 3);
      int thread_col_offset = (quadpair_col * 2) * 8;

      row_offset = warp_row_offset + thread_row_offset;
      col_offset = warp_col_offset + thread_col_offset;
    }
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  /// Problem size
  Coord<3> problem_size;

  //
  // Methods
  //

  /// Constructs the epilogue from a Params object
  CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params)
      : params(_params), problem_size(make_Coord(1024, 1024, 1024)) {}

  /// Constructs the epilogue from a pointer and leading dimension
  CUTLASS_DEVICE Volta884NaiveEpilogue(ScalarC* _ptr, int _ldm)
      : params(_ptr, _ldm), problem_size(make_Coord(1024, 1024, 1024)) {}

  /// Constructs the epilogue from a Params object and shared storage
  CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params,
                                       SharedStorage& shared_storage,
                                       Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
      : params(_params), problem_size(_problem_size) {}

  /// Sets accumulators to zero
  CUTLASS_DEVICE void clear(Accumulators& C) { C.clear(); }

  /// Naive load operation for debugging
  CUTLASS_DEVICE void load(Accumulators& C,
                           Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
    InitialOffset initial;

    initial.row_offset += threadblock_offset[2];
    initial.col_offset += threadblock_offset[1];

    ScalarC const* load_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;

    // Load accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          CUTLASS_PRAGMA_UNROLL
          for (int c = 0; c < Iterations::kC; ++c) {
            ScalarC const* op_ptr = load_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
                                    d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;

            int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));

            CUTLASS_PRAGMA_UNROLL
            for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
              int tr = c * 4;
              int tc = (reg & 3) + (reg & 4) * 2 + w * 4;

              int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
              int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;

              if (row < problem_size[2] && column < problem_size[1]) {
                C[op_idx * kAccumElementsPerInst + reg] = op_ptr[tr + tc * params.ldm];
              }
            }
          }
        }
      }
    }
  }

  /// Naive store operation for debugging
  CUTLASS_DEVICE void store(Accumulators const& C,
                            Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
    InitialOffset initial;

    initial.row_offset += threadblock_offset[2];
    initial.col_offset += threadblock_offset[1];

    ScalarC* store_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;

    // Store out accumulators
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          CUTLASS_PRAGMA_UNROLL
          for (int c = 0; c < Iterations::kC; ++c) {
            ScalarC* op_ptr = store_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
                              d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;

            int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));

            CUTLASS_PRAGMA_UNROLL
            for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
              int tr = c * 4;
              int tc = (reg & 3) + (reg & 4) * 2 + w * 4;

              int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
              int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;

              if (row < problem_size[2] && column < problem_size[1]) {
                op_ptr[tr + tc * params.ldm] =
                    params.alpha * C[op_idx * kAccumElementsPerInst + reg] +
                    params.beta * op_ptr[tr + tc * params.ldm];
              }
            }
          }
        }
      }
    }
  }

  /// CUTLASS Epilogue interface
  CUTLASS_DEVICE void epilogue(Accumulators const& C,
                               Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
    store(C, threadblock_offset);
  }

  CUTLASS_DEVICE void epilogue(Accumulators& C,
                               Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
    store(C, threadblock_offset);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,142 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/

#pragma once

#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
#include "cutlass/util/platform.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Warp-scoped shared memory load iterators
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
///! Iterator to store a thread-block scoped fragment to shared memory
|
||||
template <
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
GemmOperand::Kind Operand,
|
||||
/// Specifies layout of data in source memory
|
||||
MatrixLayout::Kind Layout,
|
||||
/// Specifies threadblock tile shape
|
||||
typename Tile,
|
||||
/// Specifies the number of participating warps
|
||||
int WarpCount,
|
||||
/// Specifies the delta between warp accesses along the outer dimension
|
||||
int WarpDelta>
|
||||
struct Volta884ThreadblockMultiplicandStoreIterator;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Iterator to load a fragment for each warp-level tile
|
||||
template <
|
||||
/// Identifies multiplicand of GEMM (A or B)
|
||||
GemmOperand::Kind Operand,
|
||||
/// Specifies layout of data in source memory
|
||||
MatrixLayout::Kind Layout,
|
||||
/// Specifies threadblock tile shape
|
||||
typename Tile,
|
||||
/// Specifies the warp tile shape
|
||||
typename WarpTile,
|
||||
/// Specifies the number of participating warps
|
||||
int WarpCount,
|
||||
/// Specifies the delta between warp accesses along the outer dimension
|
||||
typename WarpDelta>
|
||||
struct Volta884WarpMultiplicandLoadIterator;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// Fully-specialized implementations extracted in separate headers.
|
||||
//
|
||||
|
||||
#include "cutlass/gemm/volta884_shared_tile_contiguous.h"
|
||||
#include "cutlass/gemm/volta884_shared_tile_crosswise.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Epilogue shared memory iterators
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Stores an accumulator fragment to shared memory
|
||||
template <
|
||||
/// Shape of warp-level GEMM
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Data type of accumulator elements
|
||||
typename Scalar_,
|
||||
/// Data type of mma.sync accumulator - this is used to infer layout.
|
||||
typename Accumulator_>
|
||||
struct Volta884EpilogueSharedStoreIterator;
|
||||
|
||||
/// Loads an accumulator fragment from shared memory
|
||||
template <
|
||||
/// Shape of warp-level GEMM
|
||||
typename WarpGemmTile_,
|
||||
/// Tiling of warp accumulator elements
|
||||
typename WarpDelta_,
|
||||
/// Data type of accumulator elements
|
||||
typename Scalar_,
|
||||
/// Number of scalar elements loaded
|
||||
int AccessSize_,
|
||||
/// Data type of mma.sync accumulator - this is used to infer layout.
|
||||
typename Accumulator_>
|
||||
struct Volta884EpilogueSharedLoadIterator;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
//
|
||||
// Partially-specialized implementations extracted in separate header.
|
||||
//
|
||||
|
||||
#include "cutlass/gemm/volta884_shared_tile_epilogue.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
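
The forward declarations above follow a common dispatch idiom: the primary template is declared
but never defined, and each supported (operand, layout) combination is supplied as a partial
specialization by the headers included above, so an unsupported combination fails at compile time.
A minimal self-contained sketch of the same mechanism (hypothetical names, not CUTLASS API):

    enum class Operand { kA, kB };
    enum class Layout { kColumnMajor, kRowMajor };

    // Primary template: declared only, never defined.
    template <Operand Op, Layout L, typename Tile>
    struct StoreIterator;

    // Partial specialization selected for A with column-major layout; Tile remains free.
    template <typename Tile>
    struct StoreIterator<Operand::kA, Layout::kColumnMajor, Tile> {
      static constexpr char const *describe() { return "A.column_major path"; }
    };

    int main() {
      // Compiles: a specialization exists for this combination.
      return StoreIterator<Operand::kA, Layout::kColumnMajor, int>::describe() ? 0 : 1;
      // StoreIterator<Operand::kB, Layout::kColumnMajor, int> would be an incomplete type.
    }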
@ -0,0 +1,974 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/

#pragma once

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Congruous loading
//
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Store iterator specialized for A.column_major
template <
    /// Specifies threadblock tile shape
    typename Tile_,
    /// Specifies the number of participating warps
    int WarpCount,
    /// Specifies the delta between warp accesses along the outer dimension
    int WarpDelta>
struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
                                                    MatrixLayout::kColumnMajor,
                                                    Tile_,
                                                    WarpCount,
                                                    WarpDelta> {
  //
  // Constant and type definitions
  //

  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = GemmOperand::kA;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;

  /// Shape of thread-block multiplicand
  typedef Tile_ Tile;

  /// Number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp accumulator tiles along the outer dimension
  static int const kWarpDelta = WarpDelta;

  /// This implementation is specialized for 128b loads
  static int const kAccessSize = 8;

  /// Swizzle function computing each thread's initial store offset
  struct ThreadOffset {
    __device__ Coord<4> operator()() const {
      int warp_id = (threadIdx.x >> 5);
      int lane_id = (threadIdx.x & 0x1f);

      int k_idx = warp_id;

      // This is an 8-element vector within one 32x32 tile
      int vec_idx = lane_id & 3;
      int vec_col = (vec_idx / 2);

      int t4t3 = (lane_id >> 3);
      int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1);

      int t_col = (vec_col << 2) | (col_rotate ^ t4t3);

      Coord<4> offset = make_Coord(k_idx, col_rotate, t_col, 0);

      return offset;
    }
  };
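
  // Worked example of the swizzle above (editorial note, not part of the original source).
  // Lane 0: vec_idx = 0, vec_col = 0, t4t3 = 0, col_rotate = 0, so the thread stores at
  // (k_idx, 0, 0, 0). Lane 13 (0b01101): vec_idx = 1, vec_col = 0, t4t3 = 1,
  // col_rotate = ((13 >> 1) & 2) | (13 & 1) = 3, t_col = (0 << 2) | (3 ^ 1) = 2, giving
  // (k_idx, 3, 2, 0). XOR-ing the rotation with t4t3 staggers the 128b vectors so that a
  // warp's stores tend to land in distinct shared memory banks.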

  /// Projects the threadblock tile
  typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;

  /// Stored tile has a structure designed for efficient MIO storing and loading
  typedef Shape<(OperandShape::kH >> 2),  // one 3D tile per four elements in the K dimension
                (OperandShape::kW >> 4),  // four rows of SMEM per 64xK tile
                kAccessSize,              // Eight banks of MIO
                kAccessSize>
      VectorizedShape;  // 128b stores

  /// Offset between stores
  typedef Shape<WarpCount, 1, 1, 1> Delta;

  /// Number of iterations
  typedef Shape<(VectorizedShape::kD / WarpCount), (OperandShape::kW >> 6), 1, 1> Iterations;

  /// Source tile traits
  typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;

  /// Scalar type
  typedef half Scalar;

  /// Index type
  typedef int Index;

  /// Long index type
  typedef int LongIndex;

  //
  // Derived types
  //

  /// Tensor reference
  typedef TensorRef<Scalar, 4> TensorRef;

  /// Predicate vector
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  /// Fragment definition
  typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;

  /// Elements loaded by one instruction
  typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;

  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;

  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;

  /// Strides into expected SMEM tile
  typedef typename ShapeStrides<VectorizedShape, 1>::Shape Strides;

  /// Memory space access
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;

  /// Parameters object
  struct Params {
    //
    // Data members
    //

    /// Pointer to element type
    Scalar *pointer;

    /// Strides
    Coord<4> stride;

    //
    // Methods
    //

    /// Constructs a parameters object
    CUTLASS_HOST_DEVICE
    Params(Scalar *_pointer = 0)
        : pointer(_pointer),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}

    /// Constructs a params object from a TensorRef
    CUTLASS_HOST_DEVICE
    Params(TensorRef const &ref)
        : pointer(ref.data()),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  //
  // Methods
  //

  /// Constructs a store iterator
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator(
      Params const &_params,
      Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
      ThreadOffset offset_func = ThreadOffset())
      : params(_params) {
    // Compute initial thread offset
    Coord<4> offset = offset_func();

    params.pointer += (_block_offset + offset).template dot<int>(params.stride);
  }

  /// Stores a fragment
  CUTLASS_DEVICE void store(Fragment const &fragment,
                            Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
    FragmentConstIterator frag_iterator(fragment);

    // Iterate over each store
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          int idx = w + Iterations::kW * h;

          int row = idx * 4;

          Coord<4> sts_offset =
              make_Coord(d, row, 0, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          Store<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::store(
              reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, 0)),
              params.pointer,
              params.stride.template dot<int>(sts_offset + offset));
        }
      }
    }
  }

  /// Increments store iterator to next tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) {
    params.pointer +=
        make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
    return *this;
  }

  /// Increments to next tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator++() { return increment(); }

  /// Advances by a number of tiles
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator+=(int count) {
    return increment(count);
  }

  /// Decrements store iterator to previous tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) {
    params.pointer -=
        make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
    return *this;
  }

  /// Decrements to previous tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator--() { return decrement(); }

  /// Steps back by a number of tiles
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator-=(int count) {
    return decrement(count);
  }

  /// Stores a fragment and increments in the K dimension
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &store_post_increment(
      Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
    store(fragment, offset);
    return increment();
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Iterator to load a fragment for each warp-level tile specialized for A.column_major
template <
    /// Specifies threadblock tile shape
    typename Tile_,
    /// Specifies the warp tile shape
    typename WarpTile_,
    /// Specifies the number of participating warps
    int WarpCount,
    /// Specifies the delta between warp accesses along the outer dimension
    typename WarpDelta_>
struct Volta884WarpMultiplicandLoadIterator<GemmOperand::kA,
                                            MatrixLayout::kColumnMajor,
                                            Tile_,
                                            WarpTile_,
                                            WarpCount,
                                            WarpDelta_> {
  //
  // Constant and type definitions
  //

  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = GemmOperand::kA;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;

  /// Shape of thread-block multiplicand
  typedef Tile_ Tile;

  /// Shape of warp-tile matrix operation
  typedef WarpTile_ WarpTile;

  /// Hard-coded tile shape
  typedef Shape<4, 32, 32> InterleavedTileShape;

  /// Number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp accumulator tiles along the outer dimension
  typedef WarpDelta_ WarpDelta;

  /// Two SMEM read pointers are needed
  static int const kPointerCount = (WarpDelta::kW == 1 ? 2 : 1);

  /// This implementation is specialized for 128b loads
  static int const kAccessSize = 8;

  /// Swizzle function computing each thread's initial load offset
  struct ThreadOffset {
    /// Compute thread offset coordinate for each pointer
    CUTLASS_DEVICE Coord<4> operator()(int pointer_idx = 0) const {
      // Determine the warp's reading location within the SMEM tile
      int warp_id = ((threadIdx.x >> 5) % WarpDelta::kW);

      // This is an 8-element vector within one 32x32 tile
      int lane_id = (threadIdx.x & 0x1f);
      int vec_row = (lane_id >> 4);
      int vec_col = ((lane_id & 4) >> 2);

      int tile_row = pointer_idx * 2 + vec_row;

      // Column rotation function
      int t_col = (vec_col * 4);
      if (pointer_idx == 1 || (WarpDelta::kW > 1 && (warp_id & 1))) {
        vec_row |= 2;
      }

      t_col = t_col | ((lane_id & 3) ^ vec_row);

      Coord<4> offset = make_Coord(0, warp_id * 2 + tile_row, t_col, 0);

      return offset;
    }
  };

  /// Projects the threadblock tile
  typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;

  /// Stored tile has a structure designed for efficient MIO storing and loading
  typedef Shape<(OperandShape::kH >> 2),  // one 3D tile per four elements in the K dimension
                (OperandShape::kW >> 4),  // four rows of SMEM per 64xK tile
                kAccessSize,              // Eight banks of MIO
                kAccessSize>
      VectorizedShape;  // 128b stores

  /// Offset between accesses
  typedef typename platform::conditional<WarpDelta::kW == 1,
                                         Shape<1, 0, 0, 0>,
                                         Shape<1, 2 * WarpDelta::kW, 0, 0> >::type Delta;

  /// Number of iterations
  typedef Shape<1, WarpTile::kW / InterleavedTileShape::kW, 1, 1> Iterations;

  /// Source tile traits
  typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;

  /// Scalar type
  typedef half Scalar;

  /// Index type
  typedef int Index;

  /// Long index type
  typedef int LongIndex;

  //
  // Derived types
  //

  /// Tensor reference
  typedef TensorRef<Scalar, 4> TensorRef;

  /// Predicate vector
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  /// Fragment definition
  typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;

  /// Elements loaded by one instruction
  typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;

  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;

  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;

  /// Strides into expected SMEM tile
  typedef typename ShapeStrides<VectorizedShape, kAccessSize>::Shape Strides;

  /// Memory space access
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;

  /// Parameters object
  struct Params {
    //
    // Data members
    //

    /// Base pointer to SMEM allocation
    Scalar const *pointer;

    /// SMEM strides
    Coord<4> stride;

    //
    // Methods
    //

    /// Constructs a parameters object
    CUTLASS_HOST_DEVICE
    Params(Scalar const *_pointer = 0)
        : pointer(_pointer),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}

    /// Constructs a params object from a TensorRef
    CUTLASS_HOST_DEVICE
    Params(TensorRef const &ref)
        : pointer(ref.data()),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
  };

  //
  // Data members
  //

  // A.column_major requires two SMEM read pointers. Because Params supplies only a base pointer
  // and strides, there is no params data member here; Params is used solely to initialize the
  // pointers below.

  /// Pointers to the SMEM allocation
  Scalar const *pointer[kPointerCount];

  /// SMEM strides
  Coord<4> stride;

  //
  // Methods
  //

  /// Constructs a load iterator
  CUTLASS_DEVICE Volta884WarpMultiplicandLoadIterator(
      Params const &_params,
      Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
      ThreadOffset offset_func = ThreadOffset())
      : stride(_params.stride) {
    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < kPointerCount; ++idx) {
      Coord<4> offset = offset_func(idx);

      pointer[idx] = _params.pointer + (_block_offset + offset).template dot<int>(stride);
    }
  }

  /// Loads a fragment
  CUTLASS_DEVICE void load(Fragment &fragment,
                           Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
    FragmentIterator frag_iterator(fragment);

    // Iterate over each load
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          // Pointers mapped to Iterations::kH dimension
          Scalar const *_pointer = pointer[(kPointerCount == 2 ? h : 0)];

          Coord<4> lds_offset =
              make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          Load<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::load(
              reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)),
              _pointer,
              stride.template dot<int>(lds_offset + offset));
        }
      }
    }
  }

  /// Loads a fragment and increments to the next K-index
  CUTLASS_DEVICE void load_post_increment(Fragment &fragment,
                                          Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
    load(fragment, offset);

    CUTLASS_PRAGMA_UNROLL
    for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
      pointer[ptr_idx] += make_Coord(1, 0, 0, 0).template dot<int>(stride);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Store iterator specialized for B.row_major
template <
    /// Specifies threadblock tile shape
    typename Tile_,
    /// Specifies the number of participating warps
    int WarpCount,
    /// Specifies the delta between warp accesses along the outer dimension
    int WarpDelta>
struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kB,
                                                    MatrixLayout::kRowMajor,
                                                    Tile_,
                                                    WarpCount,
                                                    WarpDelta> {
  //
  // Constant and type definitions
  //

  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = GemmOperand::kB;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;

  /// Shape of thread-block multiplicand
  typedef Tile_ Tile;

  /// Number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp accumulator tiles along the outer dimension
  static int const kWarpDelta = WarpDelta;

  /// This implementation is specialized for 128b loads
  static int const kAccessSize = 8;

  /// Index type
  typedef int Index;

  /// Long index type
  typedef int LongIndex;

  /// Swizzle function computing each thread's initial store offset
  struct ThreadOffset {
    CUTLASS_DEVICE Coord<4> operator()() const {
      int warp_id = (threadIdx.x >> 5);
      int lane_id = (threadIdx.x & 0x1f);

      int k_idx = warp_id;

      // This is an 8-element vector within one 32x32 tile
      int vec_idx = lane_id & 3;
      int vec_col = (vec_idx / 2);

      int t4t3 = (lane_id >> 3);
      int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1);

      int t_col = (vec_col << 2) | (col_rotate ^ t4t3);

      Coord<4> offset = make_Coord(k_idx, col_rotate, t_col, 0);

      return offset;
    }
  };

  /// Projects the threadblock tile
  typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;

  /// Stored tile has a structure designed for efficient MIO storing and loading
  typedef Shape<(OperandShape::kH >> 2),  // one 3D tile per four elements in the K dimension
                (OperandShape::kW >> 4),  // four rows of SMEM per 64xK tile
                kAccessSize,              // Eight banks of MIO
                kAccessSize>
      VectorizedShape;  // 128b stores

  /// Offset between stores
  typedef Shape<WarpCount, 1, 1, 1> Delta;

  /// Number of iterations
  typedef Shape<(VectorizedShape::kD / WarpCount), (OperandShape::kW >> 6), 1, 1> Iterations;

  /// Source tile traits
  typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;

  /// Scalar type
  typedef half Scalar;

  //
  // Derived types
  //

  /// Tensor reference
  typedef TensorRef<Scalar, 4> TensorRef;

  /// Predicate vector
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  /// Fragment definition
  typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;

  /// Elements loaded by one instruction
  typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;

  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;

  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;

  /// Strides into expected SMEM tile
  typedef typename ShapeStrides<VectorizedShape, 1>::Shape Strides;

  /// Memory space access
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;

  /// Parameters object
  struct Params {
    //
    // Data members
    //

    /// Pointer to element type
    Scalar *pointer;

    /// Strides
    Coord<4> stride;

    //
    // Methods
    //

    /// Constructs a parameters object
    CUTLASS_HOST_DEVICE
    Params(Scalar *_pointer = 0)
        : pointer(_pointer),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}

    /// Constructs a params object from a TensorRef
    CUTLASS_HOST_DEVICE
    Params(TensorRef const &ref)
        : pointer(ref.data()),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
  };

  //
  // Data members
  //

  /// Parameters object
  Params params;

  //
  // Methods
  //

  /// Constructs a store iterator
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator(
      Params const &_params,
      Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
      ThreadOffset offset_func = ThreadOffset())
      : params(_params) {
    // Compute initial offset for each thread
    Coord<4> offset = offset_func();

    params.pointer += (_block_offset + offset).template dot<int>(params.stride);
  }

  /// Stores a fragment
  CUTLASS_DEVICE void store(Fragment const &fragment,
                            Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
    FragmentConstIterator frag_iterator(fragment);

    // Iterate over each store
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          int idx = w + Iterations::kW * h;
          int row = idx * 4;

          Coord<4> sts_offset =
              make_Coord(d, row, 0, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          Index _offset = params.stride.template dot<int>(sts_offset + offset);

          Store<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::store(
              reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, 0)),
              params.pointer,
              _offset);
        }
      }
    }
  }

  /// Increments store iterator to next tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) {
    params.pointer +=
        make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
    return *this;
  }

  /// Increments to next tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator++() { return increment(); }

  /// Advances by a number of tiles
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator+=(int count) {
    return increment(count);
  }

  /// Decrements store iterator to previous tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) {
    params.pointer -=
        make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
    return *this;
  }

  /// Decrements to previous tile
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator--() { return decrement(); }

  /// Steps back by a number of tiles
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator-=(int count) {
    return decrement(count);
  }

  /// Stores a fragment and increments in the K dimension
  CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &store_post_increment(
      Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
    store(fragment, offset);
    return increment();
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Iterator to load a fragment for each warp-level tile specialized for B.row_major
template <
    /// Specifies threadblock tile shape
    typename Tile_,
    /// Specifies the warp tile shape
    typename WarpTile_,
    /// Specifies the number of participating warps
    int WarpCount,
    /// Specifies the delta between warp accesses along the outer dimension
    typename WarpDelta_>
struct Volta884WarpMultiplicandLoadIterator<GemmOperand::kB,
                                            MatrixLayout::kRowMajor,
                                            Tile_,
                                            WarpTile_,
                                            WarpCount,
                                            WarpDelta_> {
  //
  // Constant and type definitions
  //

  /// Identifies multiplicand of GEMM (A or B)
  static GemmOperand::Kind const kOperand = GemmOperand::kB;

  /// Specifies layout of data in source memory
  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;

  /// Shape of thread-block multiplicand
  typedef Tile_ Tile;

  /// Shape of warp-tile matrix operation
  typedef WarpTile_ WarpTile;

  /// Hard-coded tile shape
  typedef Shape<4, 32, 32> InterleavedTileShape;

  /// Number of participating warps
  static int const kWarpCount = WarpCount;

  /// Delta between warp accumulator tiles along the outer dimension
  typedef WarpDelta_ WarpDelta;

  /// This implementation is specialized for 128b loads
  static int const kAccessSize = 8;

  /// Swizzle function computing each thread's initial load offset
  struct ThreadOffset {
    /// Computes the initial offset
    CUTLASS_DEVICE Coord<4> operator()(int pointer_idx) const {
      // Determine the warp's reading location within the SMEM tile
      int warp_id = ((threadIdx.x >> 5) / WarpDelta::kW);

      // This is an 8-element vector within one 32x32 tile
      int lane_id = (threadIdx.x & 0x1f);
      int vec_row = (lane_id >> 4);
      int vec_col = ((lane_id & 8) >> 3);

      int tile_row = pointer_idx * 2 + vec_row;

      // Column rotation function
      int t_col = (vec_col * 4);
      if (pointer_idx == 1 || (WarpDelta::kH > 1 && (warp_id & 1))) {
        vec_row |= 2;
      }

      t_col = t_col | ((lane_id & 3) ^ vec_row);
      Coord<4> offset = make_Coord(0, warp_id * 2 + tile_row, t_col, 0);

      return offset;
    }
  };

  /// Projects the threadblock tile
  typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;

  /// Stored tile has a structure designed for efficient MIO storing and loading
  typedef Shape<(OperandShape::kH >> 2),  // one 3D tile per four elements in the K dimension
                (OperandShape::kW >> 4),  // four rows of SMEM per 64xK tile
                kAccessSize,              // Eight banks of MIO
                kAccessSize>
      VectorizedShape;  // 128b stores

  /// Delta between accesses
  typedef typename platform::conditional<WarpDelta::kH == 1,
                                         Shape<1, 0, 0, 0>,
                                         Shape<1, 2 * WarpDelta::kH, 0, 0> >::type Delta;

  /// Number of iterations
  typedef Shape<1, WarpTile::kH / InterleavedTileShape::kH, 1, 1> Iterations;

  /// Source tile traits
  typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;

  /// Scalar type
  typedef half Scalar;

  /// Index type
  typedef int Index;

  /// Long index type
  typedef int LongIndex;

  //
  // Derived types
  //

  /// Tensor reference
  typedef TensorRef<Scalar, 4> TensorRef;

  /// Predicate vector
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  /// Fragment definition
  typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;

  /// Elements loaded by one instruction
  typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;

  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;

  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;

  /// Strides into expected SMEM tile
  typedef typename ShapeStrides<VectorizedShape, 1>::Shape Strides;

  /// Memory space access
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;

  /// Number of SMEM read pointers needed
  static int const kPointerCount = (WarpDelta::kH == 1 ? 2 : 1);

  /// Parameters object
  struct Params {
    //
    // Data members
    //

    /// Pointer to element type
    Scalar const *pointer;

    /// Strides
    Coord<4> stride;

    //
    // Methods
    //

    /// Constructs a parameters object
    CUTLASS_HOST_DEVICE
    Params(Scalar const *_pointer = 0)
        : pointer(_pointer),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}

    /// Constructs a params object from a TensorRef
    CUTLASS_HOST_DEVICE
    Params(TensorRef const &ref)
        : pointer(ref.data()),
          stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
  };

  //
  // Data members
  //

  /// Pointers to the SMEM allocation
  Scalar const *pointer[kPointerCount];

  /// Strides
  Coord<4> stride;

  //
  // Methods
  //

  /// Constructs a load iterator
  CUTLASS_DEVICE Volta884WarpMultiplicandLoadIterator(
      Params const &_params,
      Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
      ThreadOffset offset_func = ThreadOffset())
      : stride(_params.stride) {
    CUTLASS_PRAGMA_UNROLL
    for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
      Coord<4> offset = offset_func(ptr_idx);

      pointer[ptr_idx] = _params.pointer + (_block_offset + offset).template dot<int>(stride);
    }
  }

  /// Loads a fragment
  CUTLASS_DEVICE void load(Fragment &fragment,
                           Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
    FragmentIterator frag_iterator(fragment);

    // Iterate over each load
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          // Pointers mapped to Iterations::kH dimension
          Scalar const *_pointer = pointer[(kPointerCount == 2 ? h : 0)];

          Coord<4> lds_offset =
              make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          Load<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::load(
              reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)),
              _pointer,
              stride.template dot<int>(lds_offset + offset));
        }
      }
    }
  }

  /// Loads a fragment and increments to the next K-index
  CUTLASS_DEVICE void load_post_increment(Fragment &fragment,
                                          Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
    load(fragment, offset);

    CUTLASS_PRAGMA_UNROLL
    for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
      pointer[ptr_idx] += make_Coord(1, 0, 0, 0).template dot<int>(stride);
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
File diff suppressed because it is too large

@ -0,0 +1,629 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties for GEMM targeting Volta's mma.sync instruction

    DO NOT INCLUDE THIS FILE DIRECTLY.

    This file is intended to be included by <cutlass/gemm/volta884_shared_tile.h> and defines
    partial specializations for templates specified therein.
*/
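
// Usage note (editorial addition): translation units should include the umbrella header, which
// declares the primary templates before pulling in these specializations:
//
//   #include "cutlass/gemm/volta884_shared_tile.h"
//
// rather than including this file directly.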

#pragma once

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for FP32 accumulator layouts
//
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue shared memory store iterator specialized for Volta's mma.sync FP32 layout
template <
    /// Shape of warp-level GEMM
    typename WarpGemmTile_,
    /// Tiling of warp accumulator elements
    typename WarpDelta_,
    /// Data type of accumulator elements
    typename Scalar_>
struct Volta884EpilogueSharedStoreIterator<WarpGemmTile_, WarpDelta_, Scalar_, float> {
  /// Warp-scoped GEMM tile size
  typedef WarpGemmTile_ WarpGemmTile;

  /// Tiling of warp elements across threadblock
  typedef WarpDelta_ WarpDelta;

  /// Scalar data type
  typedef Scalar_ Scalar;

  /// Accumulator data type (and layout)
  typedef float Accumulator;

  /// Index type
  typedef int Index;

  /// Long index type
  typedef int LongIndex;

  /// Host-side params
  struct Params {};

  /// Access size
  static int const kAccessSize = 1;

  /// Skew elements to ensure conflict-free stores
  static int const kSkew = 2;
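
  // Editorial note on the skew: each row of the staged epilogue tile is padded by kSkew
  // elements, making the row stride WarpDelta::kW * WarpGemmTile::kW + 2. Assuming 4-byte
  // elements and 32 shared memory banks, this offsets successive rows relative to the banks,
  // which helps the strided accumulator traffic avoid systematic bank conflicts.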

  /// Shape of one interleaved mma.sync tile
  typedef Shape<4, 32, 32> MmaTileShape;

  /// Four-element fragment per mma.sync tile
  typedef Shape<WarpGemmTile::kW / MmaTileShape::kW, 1, 4, 1> Iterations;

  /// Delta between accesses - elements are separated by two columns
  typedef Shape<MmaTileShape::kW * WarpDelta::kW, 1, 2, 1> Delta;

  //
  // Dependent types
  //

  /// Predicate vector
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  /// Memory space access
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;

  /// Fragment definition
  typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;

  /// Elements loaded by one instruction
  typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;

  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;

  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;

  /// Tensor reference type
  typedef TensorRef<Scalar, 4> TensorRef;

  //
  // Data members
  //

  /// Base pointer to SMEM allocation
  Scalar *pointer;

  /// Stride in shared memory
  Coord<4> strides;

  //
  // Methods
  //

  /// Ctor
  CUTLASS_DEVICE
  Volta884EpilogueSharedStoreIterator(Params const &_params, TensorRef const &ref)
      : pointer(ref.data()),
        strides(make_Coord(1, WarpDelta::kW * WarpGemmTile::kW + kSkew, 1, 1)) {
    int warp_id = (threadIdx.x / kWarpSize);
    int lane_id = (threadIdx.x % kWarpSize);

    Coord<4> warp_idx = make_Coord(0, warp_id / WarpDelta::kW, warp_id % WarpDelta::kW, 0);

    Coord<4> warp_base = warp_idx * make_Coord(0, 4, MmaTileShape::kW, 0);

    Coord<4> thread_idx = make_Coord(0,
                                     (((lane_id >> 1) & 4) | (lane_id & 2)) >> 1,
                                     (lane_id & 1) | ((lane_id >> 1) & 8) | ((lane_id << 2) & 16),
                                     0);

    int offset = strides.template dot<int>(warp_base + thread_idx);

    pointer += offset;
  }

  /// Store to the epilogue tile.
  CUTLASS_DEVICE
  void store(Fragment const &fragment) const {
    FragmentConstIterator frag_iterator(fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          Coord<4> coord =
              make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          int _offset = coord.template dot<int>(strides);

          Store<typename Fragment::Element, kAccessSize, kMemorySpace>::store(
              reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, 0)), pointer,
              _offset);
        }
      }
    }
  }

  /// Stores to the epilogue tile - this iterator does not advance, so increment is a no-op.
  CUTLASS_DEVICE
  void store_post_increment(Fragment const &fragment) { store(fragment); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue shared memory load iterator specialized for Volta's mma.sync FP32 layout
template <
    /// Shape of warp-level GEMM
    typename WarpGemmTile_,
    /// Tiling of warp accumulator elements
    typename WarpDelta_,
    /// Data type of accumulator elements
    typename Scalar_,
    /// Number of elements loaded per access
    int AccessSize_>
struct Volta884EpilogueSharedLoadIterator<WarpGemmTile_, WarpDelta_, Scalar_, AccessSize_, float> {
  /// Warp-scoped GEMM tile size
  typedef WarpGemmTile_ WarpGemmTile;

  /// Tiling of warp elements across threadblock
  typedef WarpDelta_ WarpDelta;

  /// Scalar data type
  typedef Scalar_ Scalar;

  /// Accumulator data type (and layout)
  typedef float Accumulator;

  /// Index type
  typedef int Index;

  /// Long index type
  typedef int LongIndex;

  /// Number of elements accessed at once
  static int const kAccessSize = AccessSize_;

  /// Shape of one interleaved mma.sync tile
  typedef Shape<4, 32, 32> MmaTileShape;

  /// Total participating warps
  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;

  /// Total participating threads
  static int const kThreadCount = kWarpCount * kWarpSize;

  /// Skew elements
  static int const kSkew = 2;

  /// This tile is to be strip-mined with a swizzling function
  typedef Shape<2 * WarpDelta::kH, 2, WarpGemmTile::kW * WarpDelta::kW, 1> Tile;

  /// Number of iterations
  typedef Shape<2 * WarpDelta::kH,
                (kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
                (kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
                1>
      Iterations;

  /// Delta between accesses
  typedef Shape<2, 1, kThreadCount, 1> Delta;

  //
  // Derived quantities
  //

  /// Predicate vector
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  /// Fragment of elements to load
  typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;

  /// Elements loaded by one instruction
  typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;

  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;

  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;

  static_assert(!(kSkew % kAccessSize), "Access size must have compatible alignment with skew");

  /// Memory space access
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;

  /// Tensor reference type
  typedef TensorRef<Scalar, 4> TensorRef;

  /// Host-side params
  struct Params {};

  //
  // Data members
  //

  /// Pointer
  Scalar const *pointer;

  /// Strides
  Coord<4> strides;

  //
  // Methods
  //

  /// Constructor
  CUTLASS_DEVICE
  Volta884EpilogueSharedLoadIterator(Params const &_params, TensorRef const &ref)
      : pointer(ref.data()),
        strides(make_Coord((WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
                           (WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
                           kAccessSize,
                           1)) {
    // Strip-mine this tile
    int tid = threadIdx.x;

    int residual_w = (tid / (Tile::kW));
    int offset_w = (tid % (Tile::kW));

    int offset_h = (residual_w % Tile::kH);
    int offset_d = (residual_w / Tile::kH);

    Coord<4> offset = make_Coord(offset_d * Delta::kW, offset_h * Delta::kH, offset_w, 0);

    pointer += strides.template dot<int>(offset);
  }

  /// Loads a fragment from the epilogue tile.
  CUTLASS_DEVICE
  void load(Fragment &fragment) const {
    FragmentIterator frag_iterator(fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          // Note: the fourth delta component is the C-dimension stride (was Delta::kW in the
          // original; the coordinate's C component is zero either way).
          Coord<4> coord =
              make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          int _offset = coord.template dot<int>(strides);

          Load<typename Fragment::Element, kAccessSize, kMemorySpace>::load(
              reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), pointer, _offset);
        }
      }
    }
  }

  /// Loads a fragment - this iterator does not advance, so increment is a no-op.
  CUTLASS_DEVICE
  void load_post_increment(Fragment &fragment) { load(fragment); }
};
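
////////////////////////////////////////////////////////////////////////////////////////////////////

// Editorial sketch (assumed usage, hypothetical names - the concrete orchestration lives in the
// Volta884 epilogue, not in this file): the store/load pair above cooperates through a shared
// memory staging buffer to reorder accumulators from mma.sync layout into an output-friendly
// order.
//
//   Volta884EpilogueSharedStoreIterator<...> store_it(store_params, smem_ref);
//   store_it.store(accumulators);   // each warp writes its swizzled mma.sync fragment
//   __syncthreads();                // make the staged tile visible to all threads
//   Volta884EpilogueSharedLoadIterator<...> load_it(load_params, smem_ref);
//   load_it.load(fragment);         // threads re-read rows in global-store order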

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for FP16 accumulator layouts
//
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue shared memory store iterator specialized for Volta's mma.sync FP16 layout
template <
    /// Shape of warp-level GEMM
    typename WarpGemmTile_,
    /// Tiling of warp accumulator elements
    typename WarpDelta_,
    /// Data type of accumulator elements
    typename Scalar_>
struct Volta884EpilogueSharedStoreIterator<WarpGemmTile_, WarpDelta_, Scalar_, half> {
  /// Warp-scoped GEMM tile size
  typedef WarpGemmTile_ WarpGemmTile;

  /// Tiling of warp elements across threadblock
  typedef WarpDelta_ WarpDelta;

  /// Scalar data type
  typedef Scalar_ Scalar;

  /// Accumulator data type (and layout)
  typedef half Accumulator;

  /// Index type
  typedef int Index;

  /// Long index type
  typedef int LongIndex;

  /// Host-side params
  struct Params {};

  /// Dimensions of a contiguous 32x32x4 mma.sync tile
  typedef Shape<4, 32, 32> MmaTileShape;

  /// Accumulator fragment
  typedef Shape<WarpGemmTile::kW / MmaTileShape::kW, 1, 2, 1> Iterations;

  /// Delta between accesses
  typedef Shape<MmaTileShape::kW * WarpDelta::kW, 1, 4, 1> Delta;

  /// Access size
  static int const kAccessSize = 1;

  /// Skew elements to ensure conflict-free stores
  static int const kSkew = 2;

  /// Tensor reference type
  typedef TensorRef<Scalar, 4> TensorRef;

  //
  // Dependent types
  //

  /// Predicate vector
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  /// Memory space access
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;

  /// Fragment definition
  typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;

  /// Elements loaded by one instruction
  typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;

  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;

  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;

  //
  // Data members
  //

  /// Base pointer to SMEM allocation
  Scalar *pointer;

  /// Stride in shared memory
  Coord<4> strides;

  //
  // Methods
  //

  /// Ctor
  CUTLASS_DEVICE
  Volta884EpilogueSharedStoreIterator(Params const &_params, TensorRef const &ref)
      : pointer(ref.data()),
        strides(make_Coord(1, WarpGemmTile::kW * WarpDelta::kW + kSkew, 1, 1)) {
    int warp_id = (threadIdx.x / kWarpSize);
    int lane_id = (threadIdx.x % kWarpSize);

    int quad_id = (lane_id >> 2);
    int quadpair_id = (quad_id & 0x3);

    int quadpair_row = (quadpair_id & 1);
    int quadpair_col = (quadpair_id >> 1);
    int quad_hilo = (quad_id >> 2) & 1;

    int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 3);
    int thread_col_offset = quadpair_col;

    Coord<4> thread_idx = make_Coord(0, thread_col_offset, thread_row_offset, 0);

    Coord<4> warp_base = make_Coord(0, warp_id / WarpDelta::kW, warp_id % WarpDelta::kW, 0) *
                         make_Coord(0, 2, kWarpSize, 0);
    Coord<4> offset = warp_base + thread_idx;

    pointer += strides.template dot<int>(offset);
  }

  /// Store to the epilogue tile.
  CUTLASS_DEVICE
  void store(Fragment const &fragment) const {
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          Coord<4> coord =
              make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          int _offset = coord.template dot<int>(strides);

          Store<typename Fragment::Element, kAccessSize, kMemorySpace>::store(
              reinterpret_cast<AccessType const &>(fragment[w + Iterations::kW * d]),
              pointer,
              _offset);
        }
      }
    }
  }

  /// Stores to the epilogue tile - this iterator does not advance, so increment is a no-op.
  CUTLASS_DEVICE
  void store_post_increment(Fragment const &fragment) { store(fragment); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Epilogue shared memory load iterator specialized for Volta's mma.sync FP16 layout
template <
    /// Shape of warp-level GEMM
    typename WarpGemmTile_,
    /// Tiling of warp accumulator elements
    typename WarpDelta_,
    /// Data type of accumulator elements
    typename Scalar_,
    /// Number of elements loaded per access
    int AccessSize_>
struct Volta884EpilogueSharedLoadIterator<WarpGemmTile_, WarpDelta_, Scalar_, AccessSize_, half> {
  /// Warp-scoped GEMM tile size
  typedef WarpGemmTile_ WarpGemmTile;

  /// Tiling of warp elements across threadblock
  typedef WarpDelta_ WarpDelta;

  /// Scalar data type
  typedef Scalar_ Scalar;

  /// Accumulator data type (and layout)
  typedef half Accumulator;

  /// Number of elements accessed at once
  static int const kAccessSize = AccessSize_;

  /// Shape of one interleaved mma.sync tile
  typedef Shape<4, 32, 32> MmaTileShape;

  /// This tile is to be strip-mined with a swizzling function
  typedef Shape<1, 2 * WarpDelta::kH, WarpGemmTile::kW * WarpDelta::kW / kAccessSize, kAccessSize>
      Tile;

  /// Index type
  typedef int Index;

  /// Long index type
  typedef int LongIndex;

  /// Total participating warps
  static int const kWarpCount = ShapeCount<WarpDelta>::kCount;

  /// Number of participating threads
  static int const kThreadCount = kWarpSize * kWarpCount;

  /// Number of iterations
  typedef Shape<1,
                (kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
                (kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
                1>
      Iterations;

  /// Delta between thread-level accesses
  typedef typename platform::conditional<kThreadCount >= Tile::kW,
                                         Shape<1, (kThreadCount / Tile::kW), 1, 1>,
                                         Shape<1, 1, kThreadCount, 1> >::type Delta;

  //
  // Derived quantities
  //

  /// Predicate vector
  typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;

  /// Fragment of elements to load
  typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;

  /// Elements loaded by one instruction
  typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;

  /// The fragment iterator.
  typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;

  /// The fragment const iterator.
  typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;

  /// Skew elements
  static int const kSkew = 2;

  static_assert(!(kSkew % kAccessSize), "Access size must have compatible alignment with skew");

  /// Memory space access
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;

  /// Tensor reference type
  typedef TensorRef<Scalar, 4> TensorRef;

  /// Host-side params
  struct Params {};

  //
  // Data members
  //

  /// Pointer
  Scalar const *pointer;

  /// Strides
  Coord<4> strides;

  //
  // Methods
  //

  /// Constructor
  CUTLASS_DEVICE
  Volta884EpilogueSharedLoadIterator(Params const &_params, TensorRef const &ref)
      : pointer(ref.data()),
        strides(make_Coord(2 * (WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
                           (WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
                           kAccessSize,
                           1)) {
    // Strip-mine this tile
    Coord<4> offset = make_Coord(0, threadIdx.x / Tile::kW, threadIdx.x % Tile::kW, 0);

    pointer += strides.template dot<int>(offset);
  }

  /// Loads a fragment from the epilogue tile.
  CUTLASS_DEVICE
  void load(Fragment &fragment) const {
    FragmentIterator frag_iterator(fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w) {
          Coord<4> coord =
              make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);

          int _offset = coord.template dot<int>(strides);

          Load<typename Fragment::Element, kAccessSize, kMemorySpace>::load(
              reinterpret_cast<AccessType &>(fragment[w + Iterations::kW * h]), pointer, _offset);
        }
      }
    }
  }

  /// Loads a fragment - this iterator does not advance, so increment is a no-op.
  CUTLASS_DEVICE
  void load_post_increment(Fragment &fragment) { load(fragment); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
|
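The iterator above strip-mines the epilogue tile: each thread derives its starting (h, w) from threadIdx.x, then revisits the tile Iterations::kH times at a stride of Delta::kH rows. A minimal host-side sketch of that mapping, with hypothetical sizes (the real Tile derives from WarpGemmTile and WarpDelta):

// Sketch only: illustrates the strip-mined thread-to-tile mapping, assuming
// kThreadCount >= Tile::kW. All sizes here are hypothetical.
#include <cstdio>

int main() {
  int const kThreadCount = 128;                // kWarpSize * kWarpCount
  int const kTileW = 64;                       // Tile::kW, in kAccessSize-wide vectors
  int const kTileH = 8;                        // Tile::kH
  int const kDeltaH = kThreadCount / kTileW;   // Delta::kH
  int const kIterationsH = kTileH / kDeltaH;   // Iterations::kH

  for (int tid = 0; tid < kThreadCount; tid += 37) {  // sample a few threads
    int h0 = tid / kTileW;
    int w0 = tid % kTileW;
    printf("thread %3d starts at (h=%d, w=%d) and makes %d accesses, %d rows apart\n",
           tid, h0, w0, kIterationsH, kDeltaH);
  }
  return 0;
}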
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -106,7 +106,7 @@ struct WmmaGemmEpilogueTraitsHelper {
          // The number of scalars per LDS.
          GemmConfig_::kScalarsPerLdsD,
          // this parameter helps with swizzling when accum is fp32 and output is fp16
          sizeof(Accumulator_) / sizeof(typename GemmConfig_::ScalarD)
          int(sizeof(Accumulator_)) / int(sizeof(typename GemmConfig_::ScalarD))
          >
      SharedLoadTileTraits;
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -103,7 +103,7 @@ struct WmmaGemmGlobalIteratorCd : public GemmGlobalIteratorCd<TileTraits_, Index
        Index epilogue_stride_w,
        Index epilogue_delta_w) {
      // The pointer.
      this->pointer = pointer;
      BaseParams::pointer = pointer;
      // Stride between GEMMs
      this->stride_d = batch_stride;
      // Setup the base stride. One "group of threads" per column.
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -85,7 +85,10 @@ struct WmmaGemmMultiplyAdd {
      FragmentB const& b,
      Accumulators const& c,
      Accumulators& d) {

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Iterations::kH; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];

@ -164,7 +167,10 @@ struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
      FragmentB const& b,
      Accumulators const& c,
      Accumulators& d) {

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Iterations::kH; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];

@ -249,7 +255,10 @@ struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
      FragmentB const& b,
      Accumulators const& c,
      Accumulators& d) {

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Iterations::kH; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];

@ -329,7 +338,10 @@ struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
      FragmentB const& b,
      Accumulators const& c,
      Accumulators& d) {

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Iterations::kH; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Iterations::kW; ++i) {
        // The input elements.
        ElementA const& elt_a = a[i];
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -38,9 +38,13 @@ namespace cutlass {
template <typename InputIterator, typename Fragment>
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) {
  typename InputIterator::FragmentIterator frag_iterator(fragment);
  CUTLASS_PRAGMA_UNROLL
  for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
          if (iterator.valid(d, h, w, c)) {
            iterator.load_element(reinterpret_cast<typename InputIterator::AccessType &>(

@ -69,9 +73,13 @@ CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragme
template <typename OutputIterator, typename Fragment>
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) {
  typename OutputIterator::FragmentIterator frag_iterator(fragment);
  CUTLASS_PRAGMA_UNROLL
  for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < OutputIterator::Iterations::kC; ++c) {
          if (iterator.valid(d, h, w, c)) {
            iterator.store_element(reinterpret_cast<typename OutputIterator::AccessType &>(
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@ -0,0 +1,90 @@
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

//#include <stdexcept>
#include "cutlass/cutlass.h"
//#include "tools/util/reference/device/kernel/tensor_foreach.h"

namespace cutlass {
namespace layout {
namespace thread {

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines several helpers
namespace detail {

/// Helper to perform for-each operation
template <typename Func, int Rank, int RankRemaining>
struct TensorForEachHelper {
  /// Index of the active rank
  static int const kActiveRank = Rank - RankRemaining - 1;

  /// Constructor for general rank
  CUTLASS_DEVICE TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord) {
    for (int i = 0; i < size.at(kActiveRank); ++i) {
      coord[kActiveRank] = i;
      TensorForEachHelper<Func, Rank, RankRemaining - 1>(func, size, coord);
    }
  }
};

/// Helper to perform for-each operation
template <typename Func, int Rank>
struct TensorForEachHelper<Func, Rank, 0> {
  /// Index of the active rank
  static int const kActiveRank = Rank - 1;

  /// Constructor for fastest changing rank
  CUTLASS_DEVICE TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord) {
    for (int i = 0; i < size.at(kActiveRank); ++i) {
      coord[kActiveRank] = i;
      func(coord);
    }
  }
};

}  // namespace detail

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Iterates over the index space of a tensor
template <typename Func, int Rank, typename Params>
struct TensorForEach {
  /// Constructor performs the operation.
  CUTLASS_DEVICE TensorForEach(Coord<Rank> size, Params params = Params()) {
    Func func(params);
    Coord<Rank> coord;

    detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, size, coord);
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace thread
}  // namespace layout
}  // namespace cutlass
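TensorForEachHelper unrolls one nested loop per rank at compile time, with the RankRemaining == 0 specialization invoking the functor on each fully formed coordinate. A sketch of a caller, with a hypothetical functor (only TensorForEach, Coord, and make_Coord come from the code above; everything else is illustrative):

// Sketch only: a hypothetical functor visited over a 4x8 index space.
struct PrintCoordFunc {
  struct Params {};

  CUTLASS_DEVICE PrintCoordFunc(Params) {}

  // Called once per coordinate; the last rank varies fastest.
  CUTLASS_DEVICE void operator()(cutlass::Coord<2> const &coord) {
    printf("(%d, %d)\n", coord[0], coord[1]);
  }
};

__global__ void visit_all() {
  // Visits (0,0), (0,1), ... (3,7) within a single thread.
  cutlass::layout::thread::TensorForEach<PrintCoordFunc, 2, PrintCoordFunc::Params>(
      cutlass::make_Coord(4, 8));
}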
@ -0,0 +1,300 @@
/***************************************************************************************************
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Basic copy routines for tensor views
*/

#pragma once

#include "cutlass/fragment.h"
#include "cutlass/layout/thread/tensor_foreach.h"
#include "cutlass/tensor_view.h"

namespace cutlass {
namespace layout {
namespace thread {

/// Define a functor that performs a copy operation on a tensor.
template <typename View_dst, typename View_src>
struct CopyFunc {
  /// Coordinate of index space
  typedef typename View_dst::TensorCoord TensorCoord;

  View_dst dst;

  View_src src;

  /// Constructor
  CUTLASS_DEVICE
  CopyFunc(View_dst dst, View_src src) : dst(dst), src(src) {}

  /// copy function
  CUTLASS_DEVICE
  void operator()(TensorCoord const& coord) {
    dst.at(coord) = src.at(coord);  // uses tensor view's map()
  }
};

template <typename T_dst, typename T_src, int rank, typename MapFunc_dst, typename MapFunc_src>
struct Copy {
  CUTLASS_DEVICE void copy(cutlass::TensorView<T_dst, rank, MapFunc_dst> dst,
                           cutlass::TensorView<T_src, rank, MapFunc_src> src) {
    // Define a functor called by TensorForEach<>
    typedef CopyFunc<cutlass::TensorView<T_dst, rank, MapFunc_dst>,
                     cutlass::TensorView<T_src, rank, MapFunc_src> >
        CopyFunc;

    // Instantiate on device with TensorViews
    CopyFunc copy_func(dst, src);

    // Invoke device-side for-each computation on the tensor
    cutlass::layout::thread::TensorForEach<CopyFunc,
                                           rank,  // View::kRank
                                           CopyFunc>(src.size(), copy_func);
  }
};

template <int rank>
struct Copy<half, half, rank, cutlass::MatrixLayout::RowMajor, cutlass::MatrixLayout::RowMajor> {
  CUTLASS_DEVICE void copy(cutlass::TensorView<half, rank, cutlass::MatrixLayout::RowMajor> dst,
                           cutlass::TensorView<half, rank, cutlass::MatrixLayout::RowMajor> src) {
    bool isPacked = dst.isPacked() && src.isPacked();
    if (isPacked) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < src.capacity(); ++i) {
        dst.at(i) = src.at(i);
      }
    } else {
      typedef CopyFunc<cutlass::TensorView<half, rank, cutlass::MatrixLayout::RowMajor>,
                       cutlass::TensorView<half, rank, cutlass::MatrixLayout::RowMajor> >
          CopyFunc;

      // Instantiate on device with TensorViews
      CopyFunc copy_func(dst, src);

      // Invoke device-side for-each computation on the tensor
      cutlass::layout::thread::TensorForEach<CopyFunc,
                                             rank,  // View::kRank
                                             CopyFunc>(src.size(), copy_func);
    }
  }
};

/// hgemm swizzle
/// Transform a fragment.
template <>
struct Copy<half, half, 2, cutlass::MatrixLayout::RowMajor, cutlass::MatrixLayout::ColumnMajor> {
  CUTLASS_DEVICE void copy(
      cutlass::TensorView<half, 2, cutlass::MatrixLayout::RowMajor> dst,
      cutlass::TensorView<half, 2, cutlass::MatrixLayout::ColumnMajor> src) {
    // Expose src/dst as int arrays.
    int const* src_int = reinterpret_cast<int const*>(src.const_ref().data());
    int* dst_int = reinterpret_cast<int*>(dst.ref().data());

    int kD = src.size(0);
    int kDhw = src.size(0) * src.size(1);

    // Transpose the data.
    // CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < kD; ++d) {
      // The indices to read two consecutive "rows".
      int const i0 = 2 * d + 0;
      int const i1 = 2 * d + 1;

      int a0 = src_int[i0];
      int a1 = src_int[i1];

      int b0, b1;
      asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(a0), "r"(a1));
      asm volatile("prmt.b32 %0, %1, %2, 0x7632;" : "=r"(b1) : "r"(a0), "r"(a1));

      // The indices to store with "strides".
      int const j0 = 0 * (kDhw / 2) + d;
      int const j1 = 1 * (kDhw / 2) + d;

      dst_int[j0] = b0;
      dst_int[j1] = b1;
    }
  }
};
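The two prmt.b32 selectors perform the actual 2x2 transpose of fp16 pairs: byte selector 0x5410 gathers the low 16-bit halves of its two source words, and 0x7632 gathers the high halves, so each loop iteration turns two column-major words into two row-major ones. A host-side emulation of prmt's default mode, for reference (a sketch, not part of the library):

// Sketch only: emulates prmt.b32 in its default mode (selector nibbles 0-7;
// the sign-replicate modes are not handled here).
#include <cstdint>
#include <cstdio>

uint32_t prmt(uint32_t a, uint32_t b, uint32_t selector) {
  // Bytes 0-3 come from a, bytes 4-7 from b; each selector nibble picks one byte.
  uint64_t src = (uint64_t(b) << 32) | a;
  uint32_t d = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t byte_index = (selector >> (4 * i)) & 0x7;
    d |= ((src >> (8 * byte_index)) & 0xff) << (8 * i);
  }
  return d;
}

int main() {
  // Two words of two fp16 values each, written as recognizable byte patterns.
  uint32_t a0 = 0x11110000, a1 = 0x33332222;
  printf("0x5410 -> %08x (low halves of both words)\n", prmt(a0, a1, 0x5410));   // 22220000
  printf("0x7632 -> %08x (high halves of both words)\n", prmt(a0, a1, 0x7632));  // 33331111
  return 0;
}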
/// igemm swizzle
/// Transform a fragment.
template <>
struct Copy<int8_t,
            int8_t,
            2,
            cutlass::MatrixLayout::RowMajor,
            cutlass::MatrixLayout::ColumnMajor> {
  CUTLASS_DEVICE void copy(
      cutlass::TensorView<int8_t, 2, cutlass::MatrixLayout::RowMajor> dst,
      cutlass::TensorView<int8_t, 2, cutlass::MatrixLayout::ColumnMajor> src) {
    // Expose src/dst as int arrays.
    int const* src_int = reinterpret_cast<int const*>(src.const_ref().data());
    int* dst_int = reinterpret_cast<int*>(dst.ref().data());

    int kD = src.size(0);
    int kH = src.size(1);
    int kWc = src.stride(0);
    int kHwc = kH * kWc;

    // Transpose the data.
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < kH / 4; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < kWc / 4; ++w) {
          int const i0 = d * (kHwc / 4) + (4 * h + 0) * (kWc / 4) + w;
          int const i1 = d * (kHwc / 4) + (4 * h + 1) * (kWc / 4) + w;
          int const i2 = d * (kHwc / 4) + (4 * h + 2) * (kWc / 4) + w;
          int const i3 = d * (kHwc / 4) + (4 * h + 3) * (kWc / 4) + w;

          int a0 = src_int[i0];
          int a1 = src_int[i1];
          int a2 = src_int[i2];
          int a3 = src_int[i3];

          int b0, b1, b2, b3, c0;
          asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1));
          asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3));
          asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(b0), "r"(c0));

          asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(b1) : "r"(a0), "r"(a1));
          asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(c0) : "r"(a2), "r"(a3));
          asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b1) : "r"(b1), "r"(c0));

          asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(b2) : "r"(a0), "r"(a1));
          asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(c0) : "r"(a2), "r"(a3));
          asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b2) : "r"(b2), "r"(c0));

          asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(b3) : "r"(a0), "r"(a1));
          asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3));
          asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0));

          dst_int[i0] = b0;
          dst_int[i1] = b1;
          dst_int[i2] = b2;
          dst_int[i3] = b3;
        }
      }
    }
  }
};
template <typename Shape,
          int Rank,
          typename DstType,
          typename DstLayout,
          typename SrcType,
          typename SrcLayout>
struct Transform {

  typedef Fragment<DstType, ShapeCount<Shape>::kCount> DstFragment;
  typedef Fragment<SrcType, ShapeCount<Shape>::kCount> SrcFragment;

  /// The input fragment.
  typedef SrcFragment InputFragment;
  /// The output fragment.
  typedef DstFragment OutputFragment;

  CUTLASS_DEVICE void transform(SrcFragment& src, DstFragment& dst) {
    cutlass::TensorView<DstType, Rank, DstLayout> dstView(
        &dst[0],                            // pointer to base of matrix in device memory
        cutlass::make_Coord(Shape::kD, 1),  // stride vector
        cutlass::make_Coord(Shape::kD,
                            Shape::kH * Shape::kW)  // bounds of matrix
    );
    cutlass::TensorView<SrcType, Rank, SrcLayout> srcView(
        &src[0],                            // pointer to base of matrix in device memory
        cutlass::make_Coord(Shape::kD, 1),  // stride vector
        cutlass::make_Coord(Shape::kD,
                            Shape::kH * Shape::kW)  // bounds of matrix
    );
    cutlass::layout::thread::Copy<DstType, SrcType, Rank, DstLayout, SrcLayout> Transformer;
    Transformer.copy(dstView, srcView);
  }
};

template <typename Shape, int Rank, typename DstLayout, typename SrcLayout>
struct Transform<Shape, Rank, half, DstLayout, half, SrcLayout> {
  typedef Fragment<half, ShapeCount<Shape>::kCount> DstFragment;
  typedef Fragment<half, ShapeCount<Shape>::kCount> SrcFragment;

  /// The input fragment.
  typedef SrcFragment InputFragment;
  /// The output fragment.
  typedef DstFragment OutputFragment;

  CUTLASS_DEVICE void transform(SrcFragment& src, DstFragment& dst) {
    cutlass::TensorView<half, Rank, DstLayout> dstView(
        &dst[0],                            // pointer to base of matrix in device memory
        cutlass::make_Coord(Shape::kD, 1),  // stride vector
        cutlass::make_Coord(Shape::kD,
                            Shape::kH * Shape::kW)  // bounds of matrix
    );
    cutlass::TensorView<half, Rank, SrcLayout> srcView(
        &src[0],                            // pointer to base of matrix in device memory
        cutlass::make_Coord(Shape::kD, 1),  // stride vector
        cutlass::make_Coord(Shape::kD,
                            Shape::kH * Shape::kW)  // bounds of matrix
    );
    cutlass::layout::thread::Copy<half, half, Rank, DstLayout, SrcLayout> Transformer;
    Transformer.copy(dstView, srcView);
  }
};

template <typename Shape, int Rank, typename DstLayout, typename SrcLayout>
struct Transform<Shape, Rank, int8_t, DstLayout, int8_t, SrcLayout> {
  typedef Fragment<int8_t, ShapeCount<Shape>::kCount> DstFragment;
  typedef Fragment<int8_t, ShapeCount<Shape>::kCount> SrcFragment;

  /// The input fragment.
  typedef SrcFragment InputFragment;
  /// The output fragment.
  typedef DstFragment OutputFragment;

  CUTLASS_DEVICE void transform(SrcFragment& src, DstFragment& dst) {
    cutlass::TensorView<int8_t, Rank, DstLayout> dstView(
        &dst[0],                                        // pointer to base of matrix in device memory
        cutlass::make_Coord(Shape::kW * Shape::kC, 1),  // stride vector
        cutlass::make_Coord(Shape::kD,
                            Shape::kH)  // bounds of matrix
    );
    cutlass::TensorView<int8_t, Rank, SrcLayout> srcView(
        &src[0],                                        // pointer to base of matrix in device memory
        cutlass::make_Coord(Shape::kW * Shape::kC, 1),  // stride vector
        cutlass::make_Coord(Shape::kD,
                            Shape::kH)  // bounds of matrix
    );
    cutlass::layout::thread::Copy<int8_t, int8_t, Rank, DstLayout, SrcLayout> Transformer;
    Transformer.copy(dstView, srcView);
  }
};

}  // namespace thread
}  // namespace layout
}  // namespace cutlass
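The Transform wrappers adapt raw fragment storage to TensorViews so that the layout-specialized Copy above is selected by ordinary template specialization. A hedged usage sketch (the shape and scalar choices are illustrative, not taken from a specific kernel):

// Sketch only: transpose a fragment of halves from column-major to row-major.
// With these arguments, Copy<half, half, 2, RowMajor, ColumnMajor> — the
// prmt-based hgemm swizzle above — is the specialization that gets dispatched.
typedef cutlass::Shape<2, 4, 8, 1> FragmentShape;  // hypothetical fragment shape
typedef cutlass::layout::thread::Transform<FragmentShape, 2,
                                           half, cutlass::MatrixLayout::RowMajor,
                                           half, cutlass::MatrixLayout::ColumnMajor>
    SwizzleTransform;

__device__ void swizzle_fragment(SwizzleTransform::InputFragment &src,
                                 SwizzleTransform::OutputFragment &dst) {
  SwizzleTransform transform;
  transform.transform(src, dst);
}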
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -66,6 +66,11 @@ struct Load {
    dst = *reinterpret_cast<AccessType const*>(pointer + offset);
  }

  /// The clear function.
  static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
    dst = 0;
  }

};

////////////////////////////////////////////////////////////////////////////////////////////////////

@ -80,6 +85,11 @@ struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_
  static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
    reinterpret_cast<uint16_t&>(dst) = reinterpret_cast<uint16_t const*>(&pointer[offset])[0];
  }

  /// The clear function.
  static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
    dst = uint16_t(0);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

@ -94,6 +104,10 @@ struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_
    dst.registers[0] = reinterpret_cast<uint32_t const*>(&pointer[offset])[0];
  }

  /// The clear function.
  static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
    dst.registers[0] = uint32_t(0);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

@ -109,6 +123,13 @@ struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_
    dst.registers[0] = tmp.x;
    dst.registers[1] = tmp.y;
  }

  /// The clear function.
  static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
    uint2 const zero = make_uint2(0, 0);
    dst.registers[0] = zero.x;
    dst.registers[1] = zero.y;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

@ -124,6 +145,13 @@ struct Load<double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 1
    dst[0] = tmp.x;
    dst[1] = tmp.y;
  }

  /// The clear function.
  static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
    double2 zero = make_double2(0, 0);
    dst[0] = zero.x;
    dst[1] = zero.y;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

@ -145,6 +173,15 @@ struct Load<half, 8, Memory_, FragmentElementType::kScalar, half, kStride, 16> {
    dst.registers[2] = tmp.x;
    dst.registers[3] = tmp.y;
  }

  /// The clear function.
  static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
    int2 zero = make_int2(0, 0);
    dst.registers[0] = zero.x;
    dst.registers[1] = zero.y;
    dst.registers[2] = zero.x;
    dst.registers[3] = zero.y;
  }
};

#endif

@ -164,6 +201,15 @@ struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_
    dst.registers[2] = tmp.z;
    dst.registers[3] = tmp.w;
  }

  /// The clear function.
  static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
    uint4 zero = make_uint4(0, 0, 0, 0);
    dst.registers[0] = zero.x;
    dst.registers[1] = zero.y;
    dst.registers[2] = zero.z;
    dst.registers[3] = zero.w;
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
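Each Load specialization gains a clear() that zero-fills the same AccessType it loads, which lets callers initialize a destination before a predicated load so that masked-off elements hold zeros rather than garbage. The intended pattern, as a hedged sketch (the wrapper name and guard flag are illustrative):

// Sketch only: zero-fill an access, then load it only when its guard is true.
template <typename LoadOp, typename AccessType, typename Scalar>
CUTLASS_DEVICE void guarded_load(AccessType &dst, Scalar const *ptr, int offset, bool guard) {
  LoadOp::clear(dst);  // dst becomes all zeros, per the clear() functions above
  if (guard) {
    LoadOp::load(dst, ptr, offset);
  }
}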
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -139,7 +139,6 @@ struct MatrixCoord : public Coord<2, int> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines data layouts of various matrix formats usable by TensorRef and other classes.
//
// The following define classes satisfying the TensorRefMapFunc concept. These must support the

@ -367,6 +366,12 @@ struct MatrixTransform {
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////
/// Tensor layout
namespace TensorLayout {

enum Kind { kNHWC, kNCHW };
};
////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -186,9 +186,9 @@ struct PredicateVector {
    CUTLASS_HOST_DEVICE
    ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {}

    /// Constructs an iterator from a PredicateVector
    CUTLASS_HOST_DEVICE
    ConstIterator(PredicateVector const &_vec, int _start = 0) : vec_(_vec), bit_(_start) {}
    ConstIterator(PredicateVector const &vec, int _start = 0) : vec_(vec), bit_(_start) {}

    /// Pre-increment
    CUTLASS_HOST_DEVICE

@ -197,6 +197,13 @@ struct PredicateVector {
      return *this;
    }

    /// Increment
    CUTLASS_HOST_DEVICE
    ConstIterator &operator+=(int offset) {
      bit_ += offset;
      return *this;
    }

    /// Pre-decrement
    CUTLASS_HOST_DEVICE
    ConstIterator &operator--() {

@ -204,6 +211,13 @@ struct PredicateVector {
      return *this;
    }

    /// Decrement
    CUTLASS_HOST_DEVICE
    ConstIterator &operator-=(int offset) {
      bit_ -= offset;
      return *this;
    }

    /// Post-increment
    CUTLASS_HOST_DEVICE
    ConstIterator operator++(int) {

@ -220,6 +234,22 @@ struct PredicateVector {
      return ret;
    }

    /// Iterator advances by some amount
    CUTLASS_HOST_DEVICE
    ConstIterator operator+(int offset) {
      ConstIterator ret(*this);
      ret.bit_ += offset;
      return ret;
    }

    /// Iterator recedes by some amount
    CUTLASS_HOST_DEVICE
    ConstIterator operator-(int offset) {
      ConstIterator ret(*this);
      ret.bit_ -= offset;
      return ret;
    }

    /// Returns true if iterators point to the same bit
    CUTLASS_HOST_DEVICE
    bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; }

@ -230,7 +260,15 @@ struct PredicateVector {

    /// Dereferences iterator
    CUTLASS_HOST_DEVICE
    bool operator*() const { return vec_[bit_]; }
    bool operator*() const { return vec_.at(bit_); }

    /// Gets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    bool get() const { return vec_.at(bit_); }

    /// Gets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    bool at() const { return vec_.at(bit_); }
  };

  /**

@ -252,7 +290,7 @@ struct PredicateVector {

    /// Constructs an iterator from a PredicateVector
    CUTLASS_HOST_DEVICE
    Iterator(PredicateVector &_vec, int _start = 0) : vec_(_vec), bit_(_start) {}
    Iterator(PredicateVector &vec, int _start = 0) : vec_(vec), bit_(_start) {}

    /// Pre-increment
    CUTLASS_HOST_DEVICE

@ -261,6 +299,13 @@ struct PredicateVector {
      return *this;
    }

    /// Increment
    CUTLASS_HOST_DEVICE
    Iterator &operator+=(int offset) {
      bit_ += offset;
      return *this;
    }

    /// Pre-decrement
    CUTLASS_HOST_DEVICE
    Iterator &operator--() {

@ -268,6 +313,13 @@ struct PredicateVector {
      return *this;
    }

    /// Decrement
    CUTLASS_HOST_DEVICE
    Iterator &operator-=(int offset) {
      bit_ -= offset;
      return *this;
    }

    /// Post-increment
    CUTLASS_HOST_DEVICE
    Iterator operator++(int) {

@ -284,6 +336,22 @@ struct PredicateVector {
      return ret;
    }

    /// Iterator advances by some amount
    CUTLASS_HOST_DEVICE
    Iterator operator+(int offset) {
      Iterator ret(*this);
      ret.bit_ += offset;
      return ret;
    }

    /// Iterator recedes by some amount
    CUTLASS_HOST_DEVICE
    Iterator operator-(int offset) {
      Iterator ret(*this);
      ret.bit_ -= offset;
      return ret;
    }

    /// Returns true if iterators point to the same bit
    CUTLASS_HOST_DEVICE
    bool operator==(Iterator const &it) const { return bit_ == it.bit_; }

@ -294,11 +362,15 @@ struct PredicateVector {

    /// Gets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    bool get() { return vec_[bit_]; }
    bool get() { return vec_.at(bit_); }

    /// Gets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    bool at() const { return vec_.at(bit_); }

    /// Dereferences iterator
    CUTLASS_HOST_DEVICE
    bool operator*() const { return vec_[bit_]; }
    bool operator*() const { return at(); }

    /// Sets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
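Both iterators gain random-access arithmetic (operator+=, -=, + and -) alongside the existing increments, plus at()/get() accessors that read through PredicateVector::at(). A short usage sketch (the vector size and the set() call are assumptions, following the truncated "Sets the bit" member above):

// Sketch only: random access over predicate bits with the new operators.
CUTLASS_HOST_DEVICE void predicate_demo() {
  cutlass::PredicateVector<16> predicates;

  cutlass::PredicateVector<16>::Iterator it(predicates);
  it.set(true);          // bit 0
  (it + 5).set(true);    // bit 5, via the new advance-by-offset operator

  it += 5;               // jump the iterator itself
  bool bit5 = it.at();   // true: reads through PredicateVector::at()
  bool bit6 = *(it + 1); // false: operator* also routes through at() now
  (void)bit5;
  (void)bit6;
}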
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -80,14 +80,16 @@ struct BatchedReduction {

    typename Traits::ScalarA inRegs[Traits::maxInReg];
    typename Traits::ScalarAccum AccumRegs[Traits::maxOutReg];

#pragma unroll
    for (int subTile = 0; subTile < tileSize; subTile += subTileSize) {
      int tileOffset = subTileBase + subTileOffset;
      // Init AccumRegs
#pragma unroll
      for (int i = 0; i < Traits::ThreadShape::kW; i++)
        AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(0.0f);
      // Fetch c0
      typename Traits::ScalarAccum c0[Traits::ThreadShape::kW];
#pragma unroll
      for (int i = 0; i < Traits::ThreadShape::kW; i++)
        c0[i] = static_cast<typename Traits::ScalarAccum>(params.d_c[tileOffset + i]);

@ -131,11 +133,13 @@ struct BatchedReduction {
  template<bool ThreadShapeMultiple2>
  CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum, typename Traits::ScalarAccum const *old, typename Traits::ScalarAccum *output) {
    if (ThreadShapeMultiple2 == true) {
#pragma unroll
      for (int i = 0; i < Traits::ThreadShape::kW / 2; i++) {
        functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 2>(&accum[2 * i], &old[2 * i], &output[2 * i]);
      }
    }
    else {
#pragma unroll
      for (int i = 0; i < Traits::ThreadShape::kW; i++) {
        functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 1>(&accum[i], &old[i], &output[i]);
      }
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -72,6 +72,23 @@ struct Shape {
  static int const kC = kC_;
};

/**
 * @brief A Shape implementing \ref layout_concept describing the dimensions of a cube.
 * @concept{layout_concept}
 */
template <int kH_, int kW_>
struct Shape<1, kH_, kW_, 1> {
  /// The depth of the cube.
  static int const kD = 1;
  /// The height of the cube.
  static int const kH = kH_;
  /// The width of the cube.
  static int const kW = kW_;
  /// The number of scalars per element.
  static int const kC = 1;
};

/**
 * @brief Compute derived counts of a \ref layout_concept based class
 */
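The new partial specialization lets rank-2 tiles be written as Shape<1, kH, kW, 1> while pinning kD and kC to one, so derived helpers such as ShapeCount keep seeing the full four-dimensional interface. For instance (a sketch; ShapeCount's definition is truncated above, but per its usage elsewhere kCount is the product of the dimensions):

// Sketch only: a hypothetical 8x32 tile using the new specialization.
typedef cutlass::Shape<1, 8, 32, 1> ThreadblockTile;

static_assert(ThreadblockTile::kD == 1 && ThreadblockTile::kC == 1,
              "the specialization fixes the outer dimensions to one");
static_assert(cutlass::ShapeCount<ThreadblockTile>::kCount == 8 * 32,
              "total element count is the product of the dimensions");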
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -63,12 +63,7 @@ struct IdentityTensorMapFunc {
  and assumptions about vectorizing memory accesses throughout CUTLASS. It also matches various
  BLAS conventions in which only the "leading dimension" or most significant stride of a rank=2
  matrix is provided.

  This does affect the ability of constructing arbitrary "sparse" 2-D matrices in memory where all
  stride elements are > 1. This can be overcome by defining a custom mapping function and a
  StorageRank of 3 or more.

  Examples:

  (These examples use helpers for matrix layouts defined in cutlass/matrix_traits.h)

@ -85,7 +80,7 @@ struct IdentityTensorMapFunc {

  TensorRef<int8_t, 2, MatrixLayout::ColumnMajorInterleaved<32> > C;

  4. Defining a sparse matrix with arbitrary strides in each dimension
  4. Defining a matrix with arbitrary strides in each dimension

  struct ContiguousLayout {

@ -545,6 +540,10 @@ class TensorRef<Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_> {
  CUTLASS_HOST_DEVICE
  Storage * data() const { return ptr_; }

  /// Returns the pointer to referenced data at the given coordinate
  CUTLASS_HOST_DEVICE
  Storage * data(TensorCoord const& coord) const { return ptr_ + offset(coord); }

  /// Returns the stride of the tensor
  CUTLASS_HOST_DEVICE
  StorageCoord stride() const {
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -75,7 +75,7 @@ class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Ind
                   LongIndex_> ConstTensorRef;

  /// Base tensor reference
  typedef Base TensorRef;
  typedef Base TensorRef_t;

  /// Storage type
  typedef typename Base::Storage Storage;

@ -84,14 +84,14 @@ class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Ind
  typedef typename Base::Index Index;

  /// Coordinate in logical tensor space
  typedef typename TensorRef::TensorCoord TensorCoord;
  typedef typename TensorRef_t::TensorCoord TensorCoord;

  /// Coordinate in storage n-D array
  typedef typename TensorRef::StorageCoord StorageCoord;
  typedef typename TensorRef_t::StorageCoord StorageCoord;

  /// Stride vector in storage coordinate space
  /// Least significant stride is = 1 and not stored
  typedef typename TensorRef::StrideVector StrideVector;
  typedef typename TensorRef_t::StrideVector StrideVector;

  /// TensorView of constant value
  typedef TensorView<

@ -115,11 +115,8 @@ class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Ind
  /// Type used to compute the offset of an element to the base of a tensor
  typedef typename Base::LongIndex Offset_t;

  /// Base class
  typedef TensorRef TensorRef_t;

  /// TensorRef to const-valued type
  typedef typename TensorRef::ConstTensorRef ConstTensorRef_t;
  typedef typename TensorRef_t::ConstTensorRef ConstTensorRef_t;

 private:
  //

@ -195,14 +192,46 @@ class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Ind
    return true;
  }

  /// Returns a TensorRef pointing to the first element of the tensor.
  /// Determines the order of dims of the tensor (e.g., CHW versus HWC)
  CUTLASS_HOST_DEVICE
  TensorRef ref() const {
    return TensorRef(*this);
  void getStrideOrder(int order[]) const {
    for (int i = 0; i < Rank_; i++) order[i] = i;
    // Bubble sort into decreasing stride order
    for (int start = 0; start < Rank_ - 1; start++) {
      for (int i = 0; i < Rank_ - 1 - start; i++) {
        if (this->stride(order[i]) < this->stride(order[i + 1])) {
          int temp = order[i];
          order[i] = order[i + 1];
          order[i + 1] = temp;
        }
      }
    }
    // post-condition: this->stride(ord[i]) >= this->stride(ord[i+1]) for i from [0,Rank_-2]
  }

  /// Determines if the values in the tensor are contiguous
  CUTLASS_HOST_DEVICE
  bool isPacked() const {
    if (Rank_ <= 0) return true;
    int ord[Rank_];
    getStrideOrder(ord);
    // first check that the fastest-changing dimension has a stride of 1
    if (this->stride(ord[Rank_ - 1]) != 1) return false;
    // now check that there are no gaps between strides
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Rank_ - 1; i++)
      if (this->stride(ord[i]) != this->stride(ord[i + 1]) * size_[ord[i + 1]]) return false;
    return true;
  }

  /// Returns a TensorRef pointing to the first element of the tensor.
  CUTLASS_HOST_DEVICE
  TensorRef_t ref() const {
    return TensorRef_t(*this);
  }

  /// Returns a ConstTensorRef pointing to the first element of the tensor.
  CUTLASS_HOST_DEVICE
  ConstTensorRef const_ref() const {
    return ConstTensorRef(*this);
  }

@ -238,22 +267,22 @@ class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Ind
    return result;
  }

  /// Returns a TensorRef offset by a given amount
  /// Returns a TensorRef_t offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorView& operator+=(TensorCoord const& b) {
    this->add_pointer_offset(this->offset(b));
    return *this;
  }

  /// Returns a TensorRef offset by a given amount
  /// Returns a TensorRef_t offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorView operator-(TensorCoord const& b) const {
    TensorRef result(*this);
    TensorRef_t result(*this);
    result.add_pointer_offset(-this->offset(b));
    return result;
  }

  /// Returns a TensorRef offset by a given amount
  /// Returns a TensorRef_t offset by a given amount
  CUTLASS_HOST_DEVICE
  TensorView& operator-=(TensorCoord const& b) {
    this->add_pointer_offset(-this->offset(b));
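Read together, getStrideOrder sorts dimension indices by decreasing stride, and isPacked then verifies that each stride equals the next-smaller stride times that dimension's size, with the fastest dimension at stride 1. A worked rank-3 example, with assumed sizes:

// Sketch only: a rank-3 HWC-like view with size = (4, 8, 3).
// Packed strides (decreasing): stride(0) = 24, stride(1) = 3, stride(2) = 1.
//   order = {0, 1, 2} after the sort
//   stride(ord[2]) == 1                    -> fastest dimension is contiguous
//   stride(ord[1]) == stride(ord[2]) * 3   ->  3 == 1 * 3
//   stride(ord[0]) == stride(ord[1]) * 8   -> 24 == 3 * 8
// => isPacked() returns true; padding any stride (e.g. stride(1) = 4) breaks it.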
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -48,7 +48,7 @@ struct TileAllocation {
  typedef Scalar_ Scalar;

  /// The actual storage (may differ from the scalar type)
  typedef typename StorageType<sizeof(Scalar)>::Type Storage;
  typedef typename StorageType<int(sizeof(Scalar))>::Type Storage;

  /// Size of the allocation in units of scalars
  typedef Shape_ Shape;

@ -165,4 +165,62 @@ struct ZipTileAllocation {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Manages a triple of tile allocations as if they are one allocation
template <typename First_, typename Second_, typename Third_>
struct ZipTileAllocationTriple {
  //
  // Type definitions
  //

  /// First tensor allocation
  typedef First_ First;

  /// Second tensor allocation
  typedef Second_ Second;

  /// Metadata tensor allocation
  typedef Third_ Third;

  /// Defines the tensor reference for this allocation
  typedef Zip3TensorRef<typename First::TensorRef,
                        typename Second::TensorRef,
                        typename Third::TensorRef> TensorRef;

  /// Defines the const tensor reference for this allocation
  typedef Zip3TensorRef<typename First::ConstTensorRef,
                        typename Second::ConstTensorRef,
                        typename Third::ConstTensorRef>
      ConstTensorRef;

  //
  // Data members
  //

  /// First tensor allocation
  First first;

  /// Second tensor allocation
  Second second;

  /// Metadata tensor allocation
  Third third;

  //
  // Methods
  //

  /// Returns a TensorRef object pointing to the data
  CUTLASS_DEVICE
  TensorRef reference() {
    return TensorRef(first.reference(), second.reference(), third.reference());
  }

  /// Returns a ConstTensorRef object pointing to the data
  CUTLASS_DEVICE
  ConstTensorRef reference() const {
    return ConstTensorRef(first.reference(), second.reference(), third.reference());
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass
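A hedged usage sketch of the new triple: the TileAllocation argument order (Scalar, Shape) follows the typedefs in the hunk above, and the element types and shapes here are illustrative only.

// Sketch only: zipping three shared-memory allocations, e.g. two data tiles
// plus a metadata tile, behind a single reference.
typedef cutlass::TileAllocation<half, cutlass::Shape<1, 16, 64> > DataTile;
typedef cutlass::TileAllocation<int8_t, cutlass::Shape<1, 16, 8> > MetaTile;
typedef cutlass::ZipTileAllocationTriple<DataTile, DataTile, MetaTile> TripleAllocation;

__global__ void triple_demo() {
  __shared__ TripleAllocation storage;

  // One zipped reference spans all three underlying allocations.
  TripleAllocation::TensorRef ref = storage.reference();
  (void)ref;
}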
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -293,12 +293,10 @@ struct TileIteratorBase {
      stride_d = _stride_d;
      stride_h = _stride_h;
      stride_w = _stride_w;

      inc_w = stride_w * Delta::kW;
      inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1);
      inc_d = stride_h * Delta::kD - stride_h * Delta::kH * (Iterations::kH - 1) -
              stride_w * Delta::kW * (Iterations::kW - 1);

      inc_advance = 0;

      if (kAdvance == IteratorAdvance::kH) {

@ -740,9 +738,13 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
  template <typename Fragment, typename PredicateIterator>
  CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
    FragmentIterator frag_iterator(fragment);
    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
          CUTLASS_PRAGMA_UNROLL
          for (int c = 0; c < Iterations::kC; ++c) {
            if (*pred_it) {
              load_element(

@ -789,8 +791,11 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
  template <typename Fragment>
  CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
    FragmentIterator frag_iterator(fragment);
    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < Iterations::kH; ++h) {
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < Iterations::kW; ++w) {
        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < Iterations::kC; ++c) {
          load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
        }

@ -1076,7 +1081,6 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
                    ThreadOffset thread_offset_func = ThreadOffset())
      : params(_params), stage(0) {
    thread_offset = thread_offset_func();

    params.pointer += (block_offset[0] + thread_offset[0]) * params.stride_d +
                      (block_offset[1] + thread_offset[1]) * params.stride_h +
                      (block_offset[2] + thread_offset[2]) * params.stride_w;

@ -1148,10 +1152,13 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
  template <typename Fragment, typename PredicateIterator>
  CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) {
    FragmentConstIterator frag_iterator(fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
          CUTLASS_PRAGMA_UNROLL
          for (int c = 0; c < Iterations::kC; ++c) {
            if (*pred_it) {
              store_element(

@ -1213,9 +1220,13 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
  CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
    FragmentIterator frag_iterator(fragment);

    CUTLASS_PRAGMA_UNROLL
    for (int d = 0; d < Iterations::kD; ++d) {
      CUTLASS_PRAGMA_UNROLL
      for (int h = 0; h < Iterations::kH; ++h) {
        CUTLASS_PRAGMA_UNROLL
        for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
          CUTLASS_PRAGMA_UNROLL
          for (int c = 0; c < Iterations::kC; ++c) {
            if (*pred_it) {
              load_element(

@ -1262,8 +1273,11 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
  template <typename Fragment>
  CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
    FragmentIterator frag_iterator(fragment);
    CUTLASS_PRAGMA_UNROLL
    for (int h = 0; h < Iterations::kH; ++h) {
      CUTLASS_PRAGMA_UNROLL
      for (int w = 0; w < Iterations::kW; ++w) {
        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < Iterations::kC; ++c) {
          load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
        }
@ -1,5 +1,5 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:

@ -272,6 +272,9 @@ struct PredicatedTileLoadStream : public TileLoadStream<Iterator_, Transformer_>

  /// Parameters object used to construct generic load stream
  typedef typename Base::Params Params;

  /// Scalar type of the iterator
  typedef typename Iterator::Scalar Scalar;

  //
  // Data members

@ -331,6 +334,9 @@ struct PredicatedTileStoreStream : public TileStoreStream<Iterator_, Transformer
  /// Parameters object used to construct generic load stream
  typedef typename Base::Params Params;

  /// Scalar type of the iterator
  typedef typename Iterator::Scalar Scalar;

  //
  // Data members
  //
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -131,7 +131,7 @@ struct TileTraitsContiguousMajor {
|
|||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tiling in which warps rake across the contiguous dimension
|
||||
template <typename Tile_, int Threads>
|
||||
template <typename Tile_, int Threads, int AccessSize = 1>
|
||||
struct TileTraitsWarpRake {
|
||||
/// Shape of tile
|
||||
typedef Tile_ Tile;
|
||||
|
@@ -163,10 +163,10 @@ struct TileTraitsWarpRake {
   typedef Shape<1, kWarpsStrided, kWarpsContiguous * kWarpSize> ThreadShape;

   /// The same warp rakes along the contiguous dimension
-  typedef Shape<1, kWarpsStrided, kWarpSize> Delta;
+  typedef Shape<1, kWarpsStrided, kWarpSize * AccessSize> Delta;

   /// Number of iterations
-  typedef Shape<1, Tile::kH / Delta::kH, Tile::kW / ThreadShape::kW> Iterations;
+  typedef Shape<1, Tile::kH / Delta::kH, (Tile::kW / AccessSize) / ThreadShape::kW> Iterations;

   /// Computes the thread offset in (H, W) based on thread ID
   struct ThreadOffset {
@@ -182,7 +182,7 @@ struct TileTraitsWarpRake {
       int warp_w = (warp % kWarpsContiguous);
       int warp_h = (warp / kWarpsContiguous);

-      return make_Coord(0, warp_h, lane + kWarpSpanContiguous * warp_w, 0);
+      return make_Coord(0, warp_h, AccessSize * (lane + kWarpSpanContiguous * warp_w), 0);
     }
   };
 };
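The net effect of the new AccessSize parameter: each thread now covers AccessSize contiguous elements per access, so the inter-access stride (Delta) grows by AccessSize, the iteration count shrinks by the same factor, and ThreadOffset scales each lane's starting column so neighboring lanes own disjoint AccessSize-wide chunks. A worked example under assumed numbers (Tile::kW = 256, one raking warp with kWarpSize = 32, so ThreadShape::kW = 32):

// AccessSize = 1 (scalar accesses):
//   Delta::kW      = 32 * 1 = 32
//   Iterations::kW = (256 / 1) / 32 = 8
//
// AccessSize = 4 (e.g. one 128-bit access covering four 32-bit elements):
//   Delta::kW      = 32 * 4 = 128
//   Iterations::kW = (256 / 4) / 32 = 2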
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -84,7 +84,6 @@
  * - \p aligned_storage
  *
  * (4) Functions and types that are STL-like (but aren't in the STL):
- * - \p TODO: min and max functors?
  *
  * The idea is that, as we drop support for older compilers, we can simply #define
  * the \p __NV_STD_XYZ macros and \p platform namespace to alias their C++
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -259,6 +259,40 @@ union Vector<uint4_t, kLanes_> {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+/// Vector definition for 8-bit signed integer datatype
+template <int kLanes_>
+union Vector<int8_t, kLanes_> {
+  /// The scalar type.
+  typedef int8_t Scalar;
+
+  /// The number of elements in the vector.
+  enum { kLanes = kLanes_ };
+  /// The size of the vector.
+  enum { kVectorSize = kLanes };
+  /// The number of registers needed to store the vector.
+  enum { kRegisters = kVectorSize < 4 ? 1 : (kVectorSize + 3) / 4 };
+
+  // static_assert((kLanes >= 2) && !(kLanes % 2),
+  //               "May only construct vectors of int8_t that are multiples of 8 bits.");
+
+  /// The aligned storage to make sure we have good alignment.
+  AlignedStruct<kVectorSize> aligned_;
+  /// The data in registers.
+  uint32_t registers[kRegisters];
+
+  /// Default Constructor
+  CUTLASS_HOST_DEVICE
+  Vector() {}
+  /// Constructor to convert from uint32_t type
+  CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; }
+  /// Accessor to the ith lane.
+  CUTLASS_HOST_DEVICE int operator[](uint32_t i) const {
+    return (registers[i / 4] >> (i % 4 * 8) & 0xff);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template <typename Scalar_>
 CUTLASS_HOST_DEVICE void make_zero(Scalar_& x) {
   x = Scalar_(0);
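The new specialization packs four int8_t lanes into each 32-bit register, and operator[] extracts lane i as (registers[i / 4] >> (i % 4 * 8)) & 0xff; note the byte is masked but not sign-extended. A standalone host-side sketch of the packing and that caveat (illustrative code, not part of CUTLASS):

#include <cstdint>
#include <cstdio>

int main() {
  int8_t lanes[4] = {1, -2, 3, -4};
  uint32_t reg = 0;
  for (int i = 0; i < 4; ++i) {
    reg |= uint32_t(uint8_t(lanes[i])) << (i * 8);  // pack lane i into byte i
  }
  for (uint32_t i = 0; i < 4; ++i) {
    // Same extraction expression as the accessor above: masked, not
    // sign-extended, so lane 1 (-2) prints as 254.
    printf("lane %u = %d\n", i, int(reg >> (i % 4 * 8) & 0xff));
  }
  return 0;
}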
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -30,6 +30,10 @@
 #if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700)
 #define CUTLASS_USE_WMMA_API

 #if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 720)
 #define CUTLASS_USE_INT_WMMA
 #endif
+
+#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
+#define CUTLASS_USE_SUBBYTE_WMMA
+#endif
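These guards form a capability ladder: FP16 WMMA on sm_70 and up, integer WMMA on sm_72 and up with CUDA 10, and the new sub-byte WMMA on sm_75 and up with CUDA 10. A sketch of how dependent code would typically select a path (the function name and placeholder bodies are illustrative, not CUTLASS kernels):

void select_wmma_path() {
#if defined(CUTLASS_USE_SUBBYTE_WMMA)
  // sub-byte (e.g. int4) Tensor Core path: sm_75+, CUDA 10+
#elif defined(CUTLASS_USE_INT_WMMA)
  // 8-bit integer Tensor Core path: sm_72+, CUDA 10+
#elif defined(CUTLASS_USE_WMMA_API)
  // FP16 Tensor Core path: sm_70+
#else
  // no WMMA support in this compilation unit
#endif
}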
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -72,6 +72,57 @@ ZipTensorRef<First, Second> make_ZipTensorRef(First const &first, Second const &
   return ZipTensorRef<First, Second>(first, second);
 }

+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Aggregates three tensor references into a single object
+template <typename First_, typename Second_, typename Third_>
+struct Zip3TensorRef {
+  /// First tensor ref
+  typedef First_ First;
+
+  /// Second tensor ref
+  typedef Second_ Second;
+
+  /// Third tensor ref
+  typedef Third_ Third;
+
+  //
+  // Data members
+  //
+
+  /// First TensorRef
+  First first;
+
+  /// Second TensorRef
+  Second second;
+
+  /// Third TensorRef
+  Third third;
+
+  //
+  // Methods
+  //
+
+  CUTLASS_HOST_DEVICE
+  Zip3TensorRef() {}
+
+  CUTLASS_HOST_DEVICE
+  Zip3TensorRef(First const& _first, Second const& _second, Third const& _third) :
+      first(_first), second(_second), third(_third) {}
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Constructs a Zip3TensorRef
+template <typename First, typename Second, typename Third>
+CUTLASS_HOST_DEVICE
+Zip3TensorRef<First, Second, Third> make_Zip3TensorRef(First const &first,
+                                                       Second const &second,
+                                                       Third const &third) {
+  return Zip3TensorRef<First, Second, Third>(first, second, third);
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 } // namespace cutlass
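Zip3TensorRef mirrors the two-way ZipTensorRef defined earlier in this header: a plain aggregate that carries three independently typed references through code expecting a single ref-like object. A self-contained sketch of the same pattern, with pointers standing in for cutlass::TensorRef instances (Zip3 and make_zip3 are illustrative stand-ins, not CUTLASS names):

#include <cassert>

template <typename First, typename Second, typename Third>
struct Zip3 {
  First first;
  Second second;
  Third third;
};

template <typename F, typename S, typename T>
Zip3<F, S, T> make_zip3(F const &f, S const &s, T const &t) {
  return Zip3<F, S, T>{f, s, t};  // analogous to make_Zip3TensorRef
}

int main() {
  float a = 1.0f; int b = 2; double c = 3.0;
  auto zipped = make_zip3(&a, &b, &c);  // three heterogeneous refs, one object
  assert(*zipped.first == 1.0f && *zipped.second == 2 && *zipped.third == 3.0);
  return 0;
}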
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without modification, are permitted
  * provided that the following conditions are met:
@@ -47,6 +47,9 @@ class ZipTileIterator {

   /// Second iterator type
   typedef Second_ Second;

+  /// Scalar type of the first iterator
+  typedef typename First::Scalar Scalar;
+
   /// Params object
   struct Params {
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without modification, are permitted
 # provided that the following conditions are met:
Some files were not shown because too many files have changed in this diff.