Cutlass 1.3 Release (#42)

CUTLASS 1.3 Release
- Efficient GEMM kernel targeting Volta Tensor Cores via mma.sync instruction added in CUDA 10.1.
Andrew Kerr 2019-03-20 10:49:17 -07:00 committed by GitHub
parent 19a9d64e3c
commit 877bdcace6
256 changed files with 16930 additions and 802 deletions


@ -1,7 +1,7 @@
# NVIDIA CUTLASS Changelog
## [1.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.1) (2018-12-19)
* Resolved issue with sm50 and sm52 architectures
## [1.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.0) (2019-03-20)
* Efficient GEMM kernel targeting Volta Tensor Cores via `mma.sync` instruction added in CUDA 10.1.
## [1.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.0) (2018-10-26)
* Parallelized reductions across threadblocks ("Split-K")


@ -1,4 +1,4 @@
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@ -20,7 +20,7 @@
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.3.0)
cmake_minimum_required(VERSION 3.3.0 FATAL_ERROR)
set(CUTLASS_LANGUAGES CXX)
@ -36,7 +36,8 @@ else()
# FindCUDA fails to detect VS 2017 due to a changed directory format of the toolkits.
# For this configuration we need CMake >= 3.9.0 to use the native CUDA support.
if (WIN32 AND MSVC_VERSION GREATER 1800)
message(FATAL_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher")
message(SEND_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher")
cmake_minimum_required(VERSION 3.9.0 FATAL_ERROR)
endif()
# Fall back to the FindCUDA version to create an executable with CUDA files
@ -52,7 +53,11 @@ if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
endif()
find_package(CUDA)
find_package(CUDA REQUIRED)
include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
# Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include
# paths by default, so we add it explicitly here.
find_package(Doxygen QUIET)
###################################################################################################
@ -61,9 +66,18 @@ find_package(Doxygen QUIET)
#
###################################################################################################
find_library(CUBLAS_LIBRARY cublas HINTS
#
# Conditionally enable cuBLAS
#
set(CUTLASS_ENABLE_CUBLAS ON CACHE BOOL "Enable CUTLASS Tests to build with cuBLAS library.")
if(CUTLASS_ENABLE_CUBLAS)
find_library(CUBLAS_LIBRARY cublas HINTS
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
endif()
# By default we want to build in Release mode to ensure that we're getting best performance
if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
@ -78,26 +92,56 @@ if(WIN32)
endif()
if (WIN32)
# Enable more warnings and treat as errors
string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
# Enable more warnings and treat as errors
string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
# Disable warning on Unicode characters
string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")
# Disable warning on Unicode characters
string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
# Verbose option
if (${CUTLASS_NVCC_VERBOSE})
string(APPEND NVCC_FLAGS " -v")
endif()
# Verbose option
if (${CUTLASS_NVCC_VERBOSE})
string(APPEND NVCC_FLAGS " -v")
endif()
endif(WIN32)
set(CUTLASS_NVCC_ARCHS "50;60;61;70;75" CACHE STRING "The SM architectures to build code for.")
set(CUTLASS_NVCC_ARCHS_DEFAULT "")
if(NOT CUDA_VERSION VERSION_LESS 7.5)
list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 50)
endif()
if(NOT CUDA_VERSION VERSION_LESS 8.0)
list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 60 61)
endif()
if(NOT CUDA_VERSION VERSION_LESS 9.0)
list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 70)
endif()
if(NOT CUDA_VERSION VERSION_LESS 9.2)
list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 72)
endif()
if(NOT CUDA_VERSION VERSION_LESS 10.0)
list(APPEND CUTLASS_NVCC_ARCHS_DEFAULT 75)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_DEFAULT} CACHE STRING "The SM architectures to build code for.")
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
if (CUDA_VERSION VERSION_LESS 10.1)
set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT OFF)
else()
set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON)
endif()
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
"Enable PTX mma instruction for collective matrix multiply operations.")
set(CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST ${CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST} CACHE BOOL
"Enable more kernels instantiated in the perf suite. This might result in longer compiler time. ")
#
# NOTE: running with asan and CUDA requires the following environment variable:
#
@ -131,6 +175,18 @@ foreach(ARCH ${CUTLASS_NVCC_ARCHS})
endif()
endforeach()
if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1")
endif()
if (CUTLASS_ENABLE_CUBLAS)
string(APPEND NVCC_FLAGS " -DCUTLASS_ENABLE_CUBLAS=1")
endif()
if (CUTLASS_EXHAUSTIVE_PERFORMANCE_TEST)
add_definitions(-DEXHAUSTIVE_PROF)
endif()
if (CUTLASS_NVCC_KEEP)
string(APPEND NVCC_FLAGS " -keep")
endif()
@ -174,6 +230,7 @@ file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
file(GLOB CUTLASS_REDUCTION RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/reduction/*.h )
file(GLOB CUTLASS_LAYOUT_THREAD RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/layout/thread/*.h)
###################################################################################################
#
@ -185,16 +242,24 @@ source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
source_group("cutlass\\device" FILES ${CUTLASS_DEVICE})
source_group("cutlass\\reduction" FILES ${CUTLASS_REDUCTION})
source_group("cutlass\\layout\\thread" FILES ${CUTLASS_LAYOUT_THREAD})
source_group("cutlass" FILES ${CUTLASS_CORE})
add_library(CUTLASS INTERFACE)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
# Special policy introduced in CMake 3.13
if (POLICY CMP0076)
cmake_policy(SET CMP0076 NEW)
endif()
target_sources(CUTLASS INTERFACE
${CUTLASS_GEMM}
${CUTLASS_UTIL}
${CUTLASS_DEVICE}
${CUTLASS_CORE}
${CUTLASS_REDUCTION}
${CUTLASS_LAYOUT_THREAD}
)
target_include_directories(CUTLASS INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
@ -206,6 +271,7 @@ add_custom_target(cutlass_ide SOURCES
${CUTLASS_DEVICE}
${CUTLASS_CORE}
${CUTLASS_REDUCTION}
${CUTLASS_LAYOUT_THREAD}
)
# Doxygen is available. Generate documentation
if (DOXYGEN_FOUND)


@ -14,7 +14,7 @@ CUTLASS core components, and to identify their role in implementing GEMM computa
# <a name="S-design-patterns"></a> 1. Design Patterns
CUTLASS strives to achieve the highest performance possible on NVIDIA GPUs while also offering a
flexible composition that an be easily applied to solve new problems related to Deep Learning and
flexible composition that can be easily applied to solve new problems related to Deep Learning and
linear algebra. Though we intend to make CUTLASS as simple and straightforward as possible, given
a tradeoff between simplicity and performance, CUTLASS chooses performance. Consequently, several
design patterns are necessary to yield a composable structure while also satisfying these performance
@ -31,7 +31,7 @@ CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvla
## <a name="S-patterns-tiles-iterators"></a> Tiles and Iterators
Efficient dense linear algebra computations emphasize data movement to match the execution of mathemtical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants
Efficient dense linear algebra computations emphasize data movement to match the execution of mathematical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants
specifying element type, size, and data layout. CUTLASS refers to subpartitions as _tiles_.
_Iterators_ are familiar design patterns in C++ that provide an abstraction for accessing individual
@ -353,7 +353,7 @@ An example of splitK usage can be found [here](examples/06_splitK_gemm/splitK_ge
# Copyright
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
```
Redistribution and use in source and binary forms, with or without modification, are permitted


@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 1.2
# CUTLASS 1.3
_CUTLASS 1.2 - October 2018_
_CUTLASS 1.3.0 - March 2019_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@ -20,13 +20,18 @@ multiply-accumulate abstractions for 8-bit integer, half-precision floating
point (FP16), single-precision floating point (FP32), and double-precision floating
point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targeting
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
and beyond.
and beyond. Even faster performance on Volta is possible via direct access to
Volta Tensor Cores via `mma.sync` (added in CUDA 10.1).
CUTLASS 1.2 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
CUTLASS 1.3 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
We describe the structure of an efficient GEMM in our talk at the
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
# What's New in CUTLASS 1.3
_March 2019_
* CUTLASS 1.3 includes an efficient GEMM implementation with the `mma.sync` instruction added in CUDA 10.1.
# What's New in CUTLASS 1.2
_October 2018_
* [Parallelized Reductions](CUTLASS.md#parallel-reductions-across-gemm-k)
@ -63,8 +68,8 @@ when compiled with CUDA 10.0.
# Compatibility
CUTLASS performs best when compiled with the [CUDA 10.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.0, 9.1, and 9.2, but these versions of the CUDA Toolkit do not support new Turing WMMA features.
CUTLASS performs best when compiled with the [CUDA 10.1 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is also compatible with CUDA 9.0, 9.1, 9.2, and 10.0.
We have tested the following environments.
@ -77,7 +82,7 @@ We have tested the following environments.
| Ubuntu 18.04 | GCC 7.3.0 |
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU.
any Maxwell-, Pascal-, Volta-, and Turing-architecture NVIDIA GPUs.
|**GPU**|
|---|
@ -220,6 +225,9 @@ Program usage:
# Varies GEMM K dimension for SGEMM and IGEMM with column-major multiplicands
$ ./tools/test/perf/cutlass_perf_test --m=10240 --n=4096 --k=1024:8192:128 --kernels=sgemm_nn,igemm_nn
# Executes GEMM kernel on Volta Tensor Cores
$ ./tools/test/perf/cutlass_perf_test --kernels=s884gemm_nt
```
# About
@ -230,7 +238,7 @@ CUTLASS is released by NVIDIA Corporation as Open Source software under the
# Copyright
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
```
Redistribution and use in source and binary forms, with or without modification, are permitted
@ -253,4 +261,3 @@ Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

cutlass/arch/mma.h Normal file

@ -0,0 +1,380 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates wrapping direct issue of MMA instructions to Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/shape.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace arch {
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Specifies internal data type for computation
struct ComputeType {
enum Kind {
kBegin,
kDefault, /// Compute type implied by operand and accumulator types
kEnd
};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Direct wrapper for native MMA instruction
template <
/// Warp-level matrix multiply-accumulate operation
typename WmmaTile,
/// Layout of A multiplicand
MatrixLayout::Kind LayoutA,
/// Data type of A multiplicand
typename ScalarA,
/// Layout of B multiplicand
MatrixLayout::Kind LayoutB,
/// Data type of B multiplicand
typename ScalarB,
/// Data type of accumulators
typename ScalarC,
/// Specifies particular compute type, overriding data types of operands
ComputeType::Kind ComputeTy>
inline __device__ void mma(ScalarA const A[], ScalarB const B[], ScalarC const C[], ScalarC D[]);
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// 16x16x4
//
//
// FP16 accumulation
//
/// Volta mma.sync instruction
template <>
inline __device__ void mma<Shape<4, 16, 16>,
MatrixLayout::kRowMajor,
half,
MatrixLayout::kColumnMajor,
half,
half,
ComputeType::kDefault>(half const a[],
half const b[],
half const c[],
half d[]) {
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
unsigned const *A = reinterpret_cast<unsigned const *>(a);
unsigned const *B = reinterpret_cast<unsigned const *>(b);
unsigned const *C = reinterpret_cast<unsigned const *>(c);
unsigned *D = reinterpret_cast<unsigned *>(d);
asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
#endif
}
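For orientation, here is a minimal usage sketch of the wrapper specialized above (not part of this commit; the enclosing function name is hypothetical). Per the inline PTX, each thread contributes two 32-bit registers (four halves) for each of A and B, and four 32-bit registers (eight halves) for the FP16 accumulators.
```
#include <cuda_fp16.h>
#include "cutlass/arch/mma.h"

// Hypothetical warp-level helper: every thread of the warp passes its register-resident
// fragments, and the specialization above lowers the call to one mma.sync.aligned.m8n8k4.
__device__ void warp_mma_16x16x4_f16(half const (&a)[4],
                                     half const (&b)[4],
                                     half const (&c)[8],
                                     half (&d)[8]) {
  cutlass::arch::mma<cutlass::Shape<4, 16, 16>,
                     cutlass::MatrixLayout::kRowMajor,     // layout of A
                     half,
                     cutlass::MatrixLayout::kColumnMajor,  // layout of B
                     half,
                     half,                                 // FP16 accumulation
                     cutlass::arch::ComputeType::kDefault>(a, b, c, d);
}
```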
/// Volta mma.sync instruction
template <>
inline __device__ void mma<Shape<4, 16, 16>,
MatrixLayout::kColumnMajor,
half,
MatrixLayout::kColumnMajor,
half,
half,
ComputeType::kDefault>(half const a[],
half const b[],
half const c[],
half d[]) {
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
unsigned const *A = reinterpret_cast<unsigned const *>(a);
unsigned const *B = reinterpret_cast<unsigned const *>(b);
unsigned const *C = reinterpret_cast<unsigned const *>(c);
unsigned *D = reinterpret_cast<unsigned *>(d);
asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
#endif
}
/// Volta mma.sync instruction
template <>
inline __device__ void mma<Shape<4, 16, 16>,
MatrixLayout::kRowMajor,
half,
MatrixLayout::kRowMajor,
half,
half,
ComputeType::kDefault>(half const a[],
half const b[],
half const c[],
half d[]) {
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
unsigned const *A = reinterpret_cast<unsigned const *>(a);
unsigned const *B = reinterpret_cast<unsigned const *>(b);
unsigned const *C = reinterpret_cast<unsigned const *>(c);
unsigned *D = reinterpret_cast<unsigned *>(d);
asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
#endif
}
/// Volta mma.sync instruction
template <>
inline __device__ void mma<Shape<4, 16, 16>,
MatrixLayout::kColumnMajor,
half,
MatrixLayout::kRowMajor,
half,
half,
ComputeType::kDefault>(half const a[],
half const b[],
half const c[],
half d[]) {
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
unsigned const *A = reinterpret_cast<unsigned const *>(a);
unsigned const *B = reinterpret_cast<unsigned const *>(b);
unsigned const *C = reinterpret_cast<unsigned const *>(c);
unsigned *D = reinterpret_cast<unsigned *>(d);
asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};"
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#else
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
#endif
}
//
// FP32 accumulation
//
/// Volta mma.sync instruction
template <>
inline __device__ void mma<Shape<4, 16, 16>,
MatrixLayout::kRowMajor,
half,
MatrixLayout::kColumnMajor,
half,
float,
ComputeType::kDefault>(half const a[],
half const b[],
float const C[],
float D[]) {
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
unsigned const *A = reinterpret_cast<unsigned const *>(a);
unsigned const *B = reinterpret_cast<unsigned const *>(b);
asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};"
: "=f"(D[0]),
"=f"(D[1]),
"=f"(D[2]),
"=f"(D[3]),
"=f"(D[4]),
"=f"(D[5]),
"=f"(D[6]),
"=f"(D[7])
: "r"(A[0]),
"r"(A[1]),
"r"(B[0]),
"r"(B[1]),
"f"(C[0]),
"f"(C[1]),
"f"(C[2]),
"f"(C[3]),
"f"(C[4]),
"f"(C[5]),
"f"(C[6]),
"f"(C[7]));
#else
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
#endif
}
/// Volta mma.sync instruction
template <>
inline __device__ void mma<Shape<4, 16, 16>,
MatrixLayout::kColumnMajor,
half,
MatrixLayout::kColumnMajor,
half,
float,
ComputeType::kDefault>(half const a[],
half const b[],
float const C[],
float D[]) {
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
unsigned const *A = reinterpret_cast<unsigned const *>(a);
unsigned const *B = reinterpret_cast<unsigned const *>(b);
asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};"
: "=f"(D[0]),
"=f"(D[1]),
"=f"(D[2]),
"=f"(D[3]),
"=f"(D[4]),
"=f"(D[5]),
"=f"(D[6]),
"=f"(D[7])
: "r"(A[0]),
"r"(A[1]),
"r"(B[0]),
"r"(B[1]),
"f"(C[0]),
"f"(C[1]),
"f"(C[2]),
"f"(C[3]),
"f"(C[4]),
"f"(C[5]),
"f"(C[6]),
"f"(C[7]));
#else
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
#endif
}
/// Volta mma.sync instruction
template <>
inline __device__ void mma<Shape<4, 16, 16>,
MatrixLayout::kRowMajor,
half,
MatrixLayout::kRowMajor,
half,
float,
ComputeType::kDefault>(half const a[],
half const b[],
float const C[],
float D[]) {
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
unsigned const *A = reinterpret_cast<unsigned const *>(a);
unsigned const *B = reinterpret_cast<unsigned const *>(b);
asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};"
: "=f"(D[0]),
"=f"(D[1]),
"=f"(D[2]),
"=f"(D[3]),
"=f"(D[4]),
"=f"(D[5]),
"=f"(D[6]),
"=f"(D[7])
: "r"(A[0]),
"r"(A[1]),
"r"(B[0]),
"r"(B[1]),
"f"(C[0]),
"f"(C[1]),
"f"(C[2]),
"f"(C[3]),
"f"(C[4]),
"f"(C[5]),
"f"(C[6]),
"f"(C[7]));
#else
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
#endif
}
/// Volta mma.sync instruction
template <>
inline __device__ void mma<Shape<4, 16, 16>,
MatrixLayout::kColumnMajor,
half,
MatrixLayout::kRowMajor,
half,
float,
ComputeType::kDefault>(half const a[],
half const b[],
float const C[],
float D[]) {
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
unsigned const *A = reinterpret_cast<unsigned const *>(a);
unsigned const *B = reinterpret_cast<unsigned const *>(b);
asm volatile ("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
"{%12,%13,%14,%15,%16,%17,%18,%19};"
: "=f"(D[0]),
"=f"(D[1]),
"=f"(D[2]),
"=f"(D[3]),
"=f"(D[4]),
"=f"(D[5]),
"=f"(D[6]),
"=f"(D[7])
: "r"(A[0]),
"r"(A[1]),
"r"(B[0]),
"r"(B[1]),
"f"(C[0]),
"f"(C[1]),
"f"(C[2]),
"f"(C[3]),
"f"(C[4]),
"f"(C[5]),
"f"(C[6]),
"f"(C[7]));
#else
CUTLASS_ASSERT(0); // Collective matrix multiply instruction requires CUTLASS_ENABLE_TENSOR_CORE_MMA=1
#endif
}
} // namespace arch
} // namespace cutlass


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -29,11 +29,12 @@
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CUTLASS_MAJOR 1
#define CUTLASS_MINOR 2
#define CUTLASS_PATCH 1
#define CUTLASS_MINOR 3
#define CUTLASS_PATCH 0
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
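As a quick worked instance of the encoding above (illustrative, not part of the header): version 1.3.0 evaluates to 1*100 + 3*10 + 0 = 130.
```
// Hypothetical compile-time check of the version encoding defined above.
static_assert(CUTLASS_VERSION == 130, "CUTLASS 1.3.0 encodes as 130");
```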
#ifdef __NVCC__
@ -47,9 +48,31 @@
// CUTLASS_DEVICE is an error if not compiling device code
#endif
// CUDA 10.1 introduces the mma instruction
#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA)
#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0
#endif
// CUTLASS assert
#define CUTLASS_ASSERT(x) assert(x)
#include "cutlass/util/performance_tuning.h"
// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
#if defined(__CUDA_ARCH__)
#define CUTLASS_PRAGMA_UNROLL #pragma unroll
#define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1
#define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
#define CUTLASS_GEMM_LOOP_HEADER \
asm volatile (".pragma \"nounroll\";\n");
#else
#define CUTLASS_PRAGMA_UNROLL
#define CUTLASS_PRAGMA_NO_UNROLL
#define CUTLASS_GEMM_LOOP_HEADER
#define CUTLASS_GEMM_LOOP
#endif
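As a small illustration (not from this commit), the macros above annotate device-side loops while expanding to nothing on host passes, where `__CUDA_ARCH__` is undefined:
```
// Hypothetical device function using the unroll directive on a register-resident fragment.
template <int N>
__device__ float sum_fragment(float const (&frag)[N]) {
  float total = 0.0f;
  // N is a compile-time constant, so the loop can be fully unrolled.
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < N; ++i) {
    total += frag[i];
  }
  return total;
}
```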
// A small helper class to dump a type at compile time
// Usage:: DumpType<Class>::Class


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -98,9 +98,9 @@ struct StorageType<1> {
template <typename Element_, int kElements_, size_t kAlignment_ = 16>
struct Fragment : public AlignedStruct<kAlignment_> {
/// Make sure the alignment makes sense wrt the size of elements.
static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
static_assert(int(kAlignment_) == 16 || int(kAlignment_) >= sizeof(Element_), "Alignment is too small");
/// Alignment must be a power of two
static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
static_assert(is_pow2<int(kAlignment_)>::value, "Alignment must be a power of two");
/// This class.
typedef Fragment<Element_, kElements_> This_;
@ -109,27 +109,31 @@ struct Fragment : public AlignedStruct<kAlignment_> {
/// The number of elements.
static int const kElements = kElements_;
/// Alignment
static int const kAlignment = kAlignment_;
static int const kAlignment = int(kAlignment_);
/// Clear a fragment.
CUTLASS_HOST_DEVICE void clear() {
// Avoid element-wise access for sub 32b element type
if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
ptr[i] = uint64_t(0);
}
} else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
ptr[i] = uint32_t(0);
}
} else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
ptr[i] = uint16_t(0);
}
} else {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kElements; ++i) {
storage[i] = 0;
}
@ -146,7 +150,7 @@ struct Fragment : public AlignedStruct<kAlignment_> {
private:
/// Storage type to use for Elements
typedef typename StorageType<kAlignment_>::Type StorageType;
typedef typename StorageType<int(kAlignment_)>::Type StorageType;
/// Number of elements in the storage
static int const kStorageCount =


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -68,7 +68,6 @@ struct FragmentMultiplyAdd {
FragmentB_ const& b,
FragmentCd_ const& c,
FragmentCd_& d) {
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements; ++j) {
d[j] = b[j * kReduction + 0];


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -33,6 +33,8 @@
#include "cutlass/coord.h"
#include "cutlass/util/platform.h"
#include "cutlass/gemm/gemm.h"
namespace cutlass {
namespace gemm {
@ -47,7 +49,8 @@ struct DeviceGemm {
#if !defined(__CUDACC_RTC__)
/// Launch the kernels in order
static __host__ cudaError_t launch(Params const& params) {
Traits::GemmTraits::KernelClass::launch(params.GemmParams);
//Traits::GemmTraits::KernelClass::launch(params.GemmParams);
Gemm<typename Traits::GemmTraits>::launch(params.GemmParams);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
return err;


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -73,7 +73,7 @@ struct SplitkPIGemmTraits {
/// The pointer to workspace memory
ScalarAccum *workspace_ptr;
///
int workspace_size;
size_t workspace_size;
/// The Params for the first kernel
typename GemmTraits::Params GemmParams;
/// The Params for the second kernel
@ -112,7 +112,8 @@ struct SplitkPIGemmTraits {
Index ldc_,
ScalarD* d_d_,
Index ldd_,
ScalarAccum *workspace_ptr_) {
ScalarAccum *workspace_ptr_,
Index partitionK_multiple = 1) {
workspace_ptr = workspace_ptr_;
@ -133,7 +134,7 @@ struct SplitkPIGemmTraits {
TensorRef<typename GemmTraits::ScalarC const, 2>(workspace_ptr, problem_size.m()), /*m = ldc, workspace is not transposed and is packed*/
TensorRef<typename GemmTraits::ScalarD, 2>(workspace_ptr, problem_size.m()) /*m = ldd, workspace is not transposed and is packed*/
);
GemmParams.initialize(desc, ReductionTraits::ReductionSize);
GemmParams.initialize(desc, ReductionTraits::ReductionSize, partitionK_multiple);
//call batched reduction (second kernel) param
@ -155,9 +156,12 @@ struct SplitkPIGemmTraits {
// workspace will be used to store D (output) from the first gemm kernel (not D of the entire gemm)
// note typedef typename GemmTraits::ScalarD ScalarAccum;
// workspace of size of M * N * Reduction
int required_workspace_memory_in_byte(){
size_t required_workspace_memory_in_byte(){
assert(problem_size_initialized == true);
workspace_size = problem_size.n() * problem_size.m() * ReductionTraits::ReductionSize * static_cast<int>(sizeof(ScalarAccum));
workspace_size = static_cast<size_t>(problem_size.n()) *
static_cast<size_t>(problem_size.m()) *
static_cast<size_t>(ReductionTraits::ReductionSize) *
sizeof(ScalarAccum);
return workspace_size;
}
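To make the sizing concrete (hypothetical numbers, not from this commit): with `float` accumulators, M = 10240, N = 4096, and a split-K reduction factor of 8, the workspace is 10240 * 4096 * 8 * 4 bytes = 1,342,177,280 bytes, roughly 1.25 GiB. Values of this magnitude approach the limit of a 32-bit `int`, and larger problems overflow it, which is what the switch to `size_t` above addresses.
```
// Hypothetical sizing check mirroring required_workspace_memory_in_byte() above.
#include <cstddef>
#include <cstdio>

int main() {
  std::size_t m = 10240, n = 4096, reduction = 8;          // illustrative split-K configuration
  std::size_t bytes = n * m * reduction * sizeof(float);   // ScalarAccum assumed to be float
  std::printf("split-K workspace: %zu bytes (%.2f GiB)\n",
              bytes, double(bytes) / (1024.0 * 1024.0 * 1024.0));
  return 0;
}
```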


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -29,6 +29,7 @@
#include "cutlass/fragment.h"
#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
namespace gemm {
@ -69,8 +70,10 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, float> {
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
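// Promote each half operand to the float accumulator type (ScalarC) and accumulate element-wise.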
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
d[j * AccumulatorsPerThread::kW + i] = static_cast<ScalarC>(a[i]) * static_cast<ScalarC>(b[j]) + c[j * AccumulatorsPerThread::kW + i];
}
}


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -33,20 +33,30 @@
#include "cutlass/coord.h"
#include "cutlass/util/platform.h"
#include <cstdio>
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel with launch bounds specified
template <typename Gemm_>
__global__ __launch_bounds__(Gemm_::kThreads)
void gemm_kernel(typename Gemm_::Params params) {
// Declare shared memory.
__shared__ typename Gemm_::SharedStorage shared_storage;
// Dynamic shared memory base pointer
extern __shared__ int GemmSharedStorageBase[];
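// SharedStorage is allocated dynamically at launch time so its size may exceed the 48 KB
// static limit; see the cudaFuncAttributeMaxDynamicSharedMemorySize handling in Launch below.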
// Declare pointer to dynamic shared memory.
typename Gemm_::SharedStorage *shared_storage =
reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);
// Construct the GEMM object.
Gemm_ gemm(params, shared_storage);
Gemm_ gemm(params, *shared_storage);
// Run GEMM.
gemm.multiply_add();
}
@ -57,11 +67,17 @@ void gemm_kernel(typename Gemm_::Params params) {
template <typename Gemm_>
__global__ /* __launch_bounds__(Gemm_::kThreads) */
void gemm_kernel_nolb(typename Gemm_::Params params) {
// Declare shared memory.
__shared__ typename Gemm_::SharedStorage shared_storage;
// Dynamic shared memory base pointer
extern __shared__ int GemmSharedStorageBase[];
// Declare pointer to dynamic shared memory.
typename Gemm_::SharedStorage *shared_storage =
reinterpret_cast<typename Gemm_::SharedStorage *>(GemmSharedStorageBase);
// Construct the GEMM object.
Gemm_ gemm(params, shared_storage);
Gemm_ gemm(params, *shared_storage);
// Run GEMM.
gemm.multiply_add();
}
@ -72,7 +88,31 @@ void gemm_kernel_nolb(typename Gemm_::Params params) {
template <typename Gemm, bool WithLaunchBounds>
struct Launch {
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
gemm_kernel<Gemm><<< grid, block, 0, stream >>>(params);
int smem_size = int(sizeof(typename Gemm::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result = cudaFuncSetAttribute(
gemm_kernel<Gemm>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size
);
if (result != cudaSuccess) {
return;
}
result = cudaFuncSetAttribute(
gemm_kernel_nolb<Gemm>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
if (result != cudaSuccess) {
return;
}
}
gemm_kernel<Gemm><<< grid, block, sizeof(typename Gemm::SharedStorage), stream >>>(params);
}
};
@ -82,50 +122,51 @@ struct Launch {
template <typename Gemm>
struct Launch<Gemm, false> {
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
gemm_kernel_nolb<Gemm><<< grid, block, 0, stream >>>(params);
int smem_size = int(sizeof(typename Gemm::SharedStorage));
if (smem_size >= (48 << 10)) {
cudaError_t result = cudaFuncSetAttribute(
gemm_kernel_nolb<Gemm>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size
);
if (result != cudaSuccess) {
return;
}
result = cudaFuncSetAttribute(
gemm_kernel_nolb<Gemm>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
if (result != cudaSuccess) {
// throw exception?
return;
}
}
gemm_kernel_nolb<Gemm><<<
grid,
block,
smem_size,
stream >>>(params);
}
};
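Stripped of the CUTLASS types, the opt-in pattern used by both `Launch` specializations looks roughly like this sketch (hypothetical kernel and function names; error handling reduced to an early return, as above):
```
#include <cuda_runtime.h>

// Hypothetical kernel that keeps its working set in dynamic shared memory.
__global__ void my_kernel(int n) {
  extern __shared__ float tile[];
  if (threadIdx.x < n) {
    tile[threadIdx.x] = float(threadIdx.x);
  }
}

// Request more than the default 48 KB of dynamic shared memory before launching.
cudaError_t launch_with_large_smem(int n, cudaStream_t stream) {
  int smem_size = 96 << 10;  // e.g. 96 KB; assumes the device supports it (Volta does)
  if (smem_size >= (48 << 10)) {
    cudaError_t result = cudaFuncSetAttribute(
        my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    if (result != cudaSuccess) {
      return result;
    }
  }
  my_kernel<<< 1, 256, smem_size, stream >>>(n);
  return cudaGetLastError();
}
```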
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmTraits_>
template <typename Traits_>
struct Gemm {
/// This class.
typedef Gemm<GemmTraits_> This_;
/// The traits.
typedef GemmTraits_ Traits;
/// The shared storage.
typedef typename Traits::SharedStorage SharedStorage;
/// The scalar for A.
typedef typename Traits::ScalarA ScalarA;
/// The scalar for B.
typedef typename Traits::ScalarB ScalarB;
/// The scalar in the epilogue.
typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
/// The scalar for C.
typedef typename Traits::Epilogue::ScalarC ScalarC;
/// The scalar for D.
typedef typename Traits::Epilogue::ScalarD ScalarD;
/// The index.
typedef typename Traits::Index Index;
/// Define the mainloop iteration size
typedef typename Traits::MultiplyAdd MultiplyAdd;
/// The number of threads.
static int const kThreads = Traits::GemmConfig::kThreads;
// Number of warp-level multiply-accumulate steps executed by each warp.
static Index const kWarpGemmSteps =
Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
// Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
typedef Traits_ Traits;
/// Use the params object defined in traits
typedef typename Traits::Params Params;
typedef typename Traits::KernelClass KernelClass;
//
// Static function members
//
@ -137,7 +178,7 @@ struct Gemm {
cudaStream_t stream = cudaStreamDefault) {
// Launch the kernel.
Launch<This_, GemmTraits_::GemmConfig::kLaunchBounds>(
Launch<KernelClass, Traits::GemmConfig::kLaunchBounds>(
params, params.grid, params.block, stream);
return cudaGetLastError();
@ -164,189 +205,6 @@ struct Gemm {
}
#endif
//
// Methods
//
/// Ctor.
CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
: params(params_), shared_storage(shared_storage_) {}
/// Computes a warp-level GEMM on data held in shared memory
template <bool Residue, bool LastIteration>
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
typename Traits::SharedStream& shared_load_stream,
typename MultiplyAdd::Accumulators& accumulators,
Index outer_k) {
// If residue portion and not calculating residue in prolog, update residue predicates now.
if (Residue && outer_k <= Traits::OutputTile::kD) {
global_to_shared_stream.residue(outer_k);
}
// Load data for the next iteration of the main loop (unless it's the last iteration).
if (!LastIteration) {
global_to_shared_stream.copy();
}
CUTLASS_PRAGMA_UNROLL
for (int step = 0; step < kWarpGemmSteps - 1; ++step) {
// Trigger the copy from shared memory for the next A/B values.
shared_load_stream.copy(step + 1);
// Make sure the values are available for the current iteration to do the multiply-add.
shared_load_stream.commit(step);
MultiplyAdd multiply_add;
// Do the math on the fragments of the current iteration.
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
shared_load_stream.fragment_b(step),
accumulators,
accumulators);
}
// Make sure the data from shared memory has been entirely consumed.
Traits::shared_load_fence(true);
// Commit the data in shared memory for A/B.
if (!LastIteration) {
global_to_shared_stream.commit();
}
// Make sure the data is in shared memory.
Traits::shared_store_fence(true);
if (!LastIteration) {
// Move to the next stage for the load (if it makes sense).
shared_load_stream.inc_stage();
// Trigger the copy from shared memory for the next loop iteration.
shared_load_stream.copy(0);
}
// Make sure the values are available for the current iteration to do the multiply-add.
shared_load_stream.commit(kWarpGemmSteps - 1);
// Do the math on the fragments of the current iteration.
MultiplyAdd multiply_add;
multiply_add.multiply_add(shared_load_stream.fragment_a(kWarpGemmSteps - 1),
shared_load_stream.fragment_b(kWarpGemmSteps - 1),
accumulators,
accumulators);
}
/// Do the GEMM.
CUTLASS_DEVICE void multiply_add() {
// Swizzle the IDs of the block (to enable better cache behavior).
typename Traits::BlockSwizzle block_swizzle;
Coord<3> threadblock_offset =
block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::OutputTile>());
// We may want to use shared memory to clear the registers.
typedef typename Traits::ClearAccumulators ClearAccumulators;
// Get the bounds for each thread, it maybe different than problem_size
Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
params.partitionK_range);
// The streams to read A/B from global memory to shared memory.
typename Traits::GlobalLoadStream global_to_shared_stream(
params.global_to_shared_stream,
shared_storage.main_loop.global_to_shared_stream,
shared_storage.main_loop.threadblock_tile.reference(),
bounds,
threadblock_offset);
// update A and B pointer offset based on batch_id and batch_stride_offset
global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());
// Create the accumulator clear.
ClearAccumulators clear;
// Deal with residue in prolog.
// global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);
// Fetch the fragments for A and B from global memory.
global_to_shared_stream.copy();
// Copy the elements to shared memory (after transformation if needed).
global_to_shared_stream.commit();
// Make sure the data is in shared memory.
Traits::shared_store_fence(false);
// Rollback to the beginning of the first tile (if residue exists).
// global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);
// The stream of data from shared memory to fragments.
typename Traits::SharedStream shared_load_stream(
params.shared_stream,
shared_storage.main_loop.threadblock_tile.reference());
// Trigger the copy from shared memory for the 1st stream.
shared_load_stream.copy(0);
// Allocate the accumulators.
typename MultiplyAdd::Accumulators accumulators;
// Clear the accumulators.
clear.clear(accumulators);
// Initial index
// Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
// problem_size[0] might be bigger than bounds[0]
Index outer_k = bounds[0] - Traits::OutputTile::kD;
// Check if we are computing residue in prolog or not.
if (Traits::GemmConfig::kResidueInProlog) {
// Execute all mainloop iterations but the last one.
CUTLASS_GEMM_LOOP
for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
consume_tile<false, false>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
// Don't load data for the last "residue" portion since we've already computed the residue.
CUTLASS_GEMM_LOOP
for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
consume_tile<false, true>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
} else {
// When kResidueSeparate = true, execute all mainloop iterations but the last two without any
// consideration for K-residue or predicate updates. This improves the steady state of some
// kernels.
if (Traits::GemmConfig::kResidueSeparate) {
CUTLASS_GEMM_LOOP
for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
consume_tile<false, false>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
}
// Execute remaining tiles with K-residue predicate updates enabled.
CUTLASS_GEMM_LOOP
for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
consume_tile<true, false>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
}
// Epilogue.
typedef typename Traits::Epilogue Epilogue;
Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
}
//
// Data members
//
/// The params.
Params const& params;
/// The shared storage.
SharedStorage& shared_storage;
};
////////////////////////////////////////////////////////////////////////////////////////////////////


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:


@ -0,0 +1,274 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements a software-pipelined efficient GEMM.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/coord.h"
namespace cutlass {
namespace gemm {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits_>
struct GemmMainloop {
//
// Type definitions
//
/// The traits.
typedef Traits_ Traits;
/// The GEMM mainloop
typedef typename Traits::KernelClass KernelClass;
/// The shared storage.
typedef typename Traits::SharedStorage SharedStorage;
/// The scalar for A.
typedef typename Traits::ScalarA ScalarA;
/// The scalar for B.
typedef typename Traits::ScalarB ScalarB;
/// The scalar in the epilogue.
typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
/// The scalar for C.
typedef typename Traits::Epilogue::ScalarC ScalarC;
/// The scalar for D.
typedef typename Traits::Epilogue::ScalarD ScalarD;
/// The index.
typedef typename Traits::Index Index;
/// Define the mainloop iteration size
typedef typename Traits::MultiplyAdd MultiplyAdd;
/// The number of threads.
static int const kThreads = Traits::GemmConfig::kThreads;
// Number of warp-level multiply-accumulate steps executed by each warp.
static Index const kWarpGemmSteps =
Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
/*
// Make sure we have at least 2 unrolling steps or our pipelining is not going to work.
static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
*/
/// Use the params object defined in traits
typedef typename Traits::Params Params;
//
// Data members
//
/// The params.
Params const& params;
/// SharedStorage object
SharedStorage& shared_storage;
//
// Methods
//
/// Ctor.
CUTLASS_DEVICE GemmMainloop(Params const& params_, SharedStorage& shared_storage_)
: params(params_), shared_storage(shared_storage_) {}
/// Fetches global stream pair
template <bool Residue>
CUTLASS_DEVICE void fetch_global(typename Traits::GlobalLoadStream& global_to_shared_stream,
Index outer_k) {
// If residue portion and not calculating residue in prolog, update residue predicates now.
if (Residue) {
global_to_shared_stream.residue(outer_k);
}
global_to_shared_stream.copy();
}
/// Computes a warp-level GEMM on data held in shared memory
template <bool Residue, bool LastIteration>
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
typename Traits::SharedStream& shared_load_stream,
typename MultiplyAdd::Accumulators& accumulators,
Index outer_k) {
// Whether to load global stream before loading shared stream
const bool kGlobalStreamFirst = (kWarpGemmSteps <= 4);
// Load data for the next iteration of the main loop (unless it's the last iteration).
if (kGlobalStreamFirst && !LastIteration) {
fetch_global<Residue>(global_to_shared_stream, outer_k);
}
CUTLASS_PRAGMA_UNROLL
for (int step = 0; step < kWarpGemmSteps; ++step) {
// Trigger the copy from shared memory for the next A/B values.
shared_load_stream.copy((step + 1) % kWarpGemmSteps);
// Load data for the next iteration of the main loop (unless it's the last iteration).
if (!kGlobalStreamFirst && (step == 0) && !LastIteration) {
fetch_global<Residue>(global_to_shared_stream, outer_k);
}
if (step == kWarpGemmSteps - 2) {
// Make sure the data from shared memory has been entirely consumed.
Traits::shared_load_fence(true);
global_to_shared_stream.commit();
// Make sure the data is in shared memory.
Traits::shared_store_fence(true);
// Move to the next stage for the load (if it makes sense).
shared_load_stream.inc_stage();
}
// Make sure the values are available for the current iteration to do the multiply-add.
shared_load_stream.commit(step);
// Do the math on the fragments of the current iteration.
MultiplyAdd multiply_add;
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
shared_load_stream.fragment_b(step),
accumulators,
accumulators);
}
}
/// Do the GEMM.
CUTLASS_DEVICE void multiply_add() {
// Swizzle the IDs of the block (to enable better cache behavior).
typename Traits::BlockSwizzle block_swizzle;
Coord<3> threadblock_offset =
block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::OutputTile>());
// We may want to use shared memory to clear the registers.
typedef typename Traits::ClearAccumulators ClearAccumulators;
// Get the bounds for each thread, it maybe different than problem_size
Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
params.partitionK_range);
// The streams to read A/B from global memory to shared memory.
typename Traits::GlobalLoadStream global_to_shared_stream(
params.global_to_shared_stream,
shared_storage.main_loop.global_to_shared_stream,
shared_storage.main_loop.threadblock_tile.reference(),
bounds,
threadblock_offset);
// Update the A and B pointer offsets based on batch_id and the batch stride.
global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());
// Create the accumulator clear.
ClearAccumulators clear;
// Deal with residue in prolog.
// global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);
// Fetch the fragments for A and B from global memory.
global_to_shared_stream.copy();
// Copy the elements to shared memory (after transformation if needed).
global_to_shared_stream.commit();
// Make sure the data is in shared memory.
Traits::shared_store_fence(false);
// Rollback to the beginning of the first tile (if residue exists).
// global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);
// The stream of data from shared memory to fragments.
typename Traits::SharedStream shared_load_stream(
params.shared_stream,
shared_storage.main_loop.threadblock_tile.reference());
// Trigger the copy from shared memory for the 1st stream.
shared_load_stream.copy(0);
// Allocate the accumulators.
typename MultiplyAdd::Accumulators accumulators;
// Clear the accumulators.
clear.clear(accumulators);
// Initial index
// Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
// problem_size[0] might be bigger than bounds[0]
Index outer_k = bounds[0] - Traits::OutputTile::kD;
// Check if we are computing residue in prolog or not.
if (Traits::GemmConfig::kResidueInProlog) {
// Execute all mainloop iterations but the last one.
CUTLASS_GEMM_LOOP
for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
CUTLASS_GEMM_LOOP_HEADER
consume_tile<false, false>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
consume_tile<false, true>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
} else {
// When kResidueSeparate = true, execute all mainloop iterations but the last two without any
// consideration for K-residue or predicate updates. This improves the steady state of some
// kernels.
if (Traits::GemmConfig::kResidueSeparate) {
CUTLASS_GEMM_LOOP
for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
CUTLASS_GEMM_LOOP_HEADER
consume_tile<false, false>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
}
// Execute remaining tiles with K-residue predicate updates enabled.
CUTLASS_GEMM_LOOP
for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
CUTLASS_GEMM_LOOP_HEADER
consume_tile<true, false>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
}
typedef typename Traits::Epilogue Epilogue;
Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
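
A side note on the pipelining in consume_tile() above: the warp consumes one register fragment while the fragment for the next step is being fetched. The following host-only sketch is not CUTLASS code; load_shared, the tile count, and the step count are made up for illustration, but the ping-pong pattern is the same.

// Host-side sketch of the double-buffered pipeline used by consume_tile().
// "load_shared" stands in for the shared-memory iterator and the += stands in
// for the warp-level multiply-accumulate.
#include <cstdio>

int main() {
  const int kWarpGemmSteps = 4;           // warp-level K-steps per threadblock tile
  const int kTiles = 3;                   // number of mainloop iterations
  float fragment[2] = {0.0f, 0.0f};       // double-buffered register fragments
  float accumulator = 0.0f;

  auto load_shared = [](int tile, int step) {   // pretend shared-memory load
    return static_cast<float>(tile * 10 + step);
  };

  // Prologue: fetch the first fragment before entering the loop.
  fragment[0] = load_shared(0, 0);

  for (int tile = 0; tile < kTiles; ++tile) {
    for (int step = 0; step < kWarpGemmSteps; ++step) {
      // Issue the load for the *next* step into the other buffer...
      int next_step = (step + 1) % kWarpGemmSteps;
      int next_tile = (next_step == 0 ? tile + 1 : tile);
      if (next_tile < kTiles) {
        fragment[(step + 1) % 2] = load_shared(next_tile, next_step);
      }
      // ...while the math consumes the fragment fetched previously.
      accumulator += fragment[step % 2];
    }
  }
  std::printf("accumulator = %f\n", accumulator);
  return 0;
}

Prefetching into the other buffer keeps the math for step independent of the load issued for step + 1, which is what lets the shared-memory loads overlap the multiply-accumulate.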

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -91,8 +91,16 @@ struct SharedLoadStream {
transformer = Transformer();
}
/// Clears the fragments
CUTLASS_DEVICE void clear() {
fetched[0].clear();
fetched[1].clear();
transformed[0].clear();
transformed[1].clear();
}
/// Load the data from shared memory to the fetch fragment.
CUTLASS_DEVICE void copy() {
iterator.load_post_increment(fetched[0]);
}

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -140,14 +140,19 @@ struct GlobalLoadStreamPair {
/// Trigger the loads from global memory.
CUTLASS_DEVICE void copy() {
stream_a.copy();
stream_b.copy();
}
/// Commit the data.
CUTLASS_DEVICE void commit() {
stream_a.commit();
stream_b.commit();
}
/// Execute the residue code.
@ -233,6 +238,13 @@ struct SharedStreamPair {
stream_b.commit(step);
}
/// Clears all fragments
CUTLASS_DEVICE
void clear() {
stream_a.clear();
stream_b.clear();
}
/// The fragment A.
CUTLASS_DEVICE
typename StreamA::TransformedFragment const &fragment_a(int step) const {

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -42,7 +42,7 @@
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/gemm/gemm_shared_stream.h"
#include "cutlass/gemm/threadblock_swizzle.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm_mainloop.h"
namespace cutlass {
namespace gemm {
@ -359,7 +359,7 @@ struct GemmTraits {
ClearAccumulators_> This_;
/// The struct that consumes this Traits
typedef typename cutlass::gemm::Gemm<This_> KernelClass;
typedef typename cutlass::gemm::GemmMainloop<This_> KernelClass;
/// The configuration.
typedef GemmConfig_ GemmConfig;
@ -544,16 +544,26 @@ struct GemmTraits {
/// Helper to construct a partitionedK GEMM params
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, Index partitionK_count_) {
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc,
Index partitionK_count_,
Index partitionK_multiple_ = 1 // each partition will be a multiple of partitionK_multiple_
) {
// partitionK GEMM is a specialized batched strided GEMM with a different K range per batch;
// the problem_size of each batch is (lastK_size, n, m).
// The K range for every batch except the last one:
//assert(partitionK_count_ > 0);
partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_;
partitionK_range = partitionK_range - (partitionK_range % partitionK_multiple_);
// the k range of the last batch
// int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range;
int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1);
assert((partitionK_range % partitionK_multiple_) == 0);
assert(partitionK_range > 0);
assert((lastK_range % partitionK_multiple_) == 0);
assert(lastK_range > 0);
int k_size = lastK_range;
int lda = partitonK_desc.A.stride(0);
int ldb = partitonK_desc.B.stride(0);
@ -641,7 +651,8 @@ struct GemmTraits {
Index ldc,
ScalarD* d_d,
Index ldd,
Index partitionK_count_) {
Index partitionK_count_,
Index partitionK_multiple_ = 1) {
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
GemmCoord(k, n, m, 1),
@ -654,7 +665,7 @@ struct GemmTraits {
);
return this->initialize(desc, partitionK_count_);
return this->initialize(desc, partitionK_count_, partitionK_multiple_);
}
};
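
The partitionK_multiple_ rounding in initialize() above is easiest to check with concrete numbers. A minimal host-side sketch, with K = 1000, 8 partitions, and a multiple of 4 chosen purely as an example:

// Sketch of the partitioned-K range computation performed by initialize().
// Numbers are illustrative; the code above asserts that both ranges stay
// positive and remain multiples of partitionK_multiple.
#include <cassert>
#include <cstdio>

int main() {
  int k = 1000;                  // total GEMM K dimension
  int partitionK_count = 8;      // number of partitions (batches)
  int partitionK_multiple = 4;   // each partition's K must be a multiple of this

  // K range for every partition except the last, rounded down to a multiple.
  int partitionK_range = k / partitionK_count;                       // 125
  partitionK_range -= partitionK_range % partitionK_multiple;        // 124

  // The last partition absorbs the remainder.
  int lastK_range = k - partitionK_range * (partitionK_count - 1);   // 132

  assert(partitionK_range > 0 && partitionK_range % partitionK_multiple == 0);
  assert(lastK_range > 0 && lastK_range % partitionK_multiple == 0);

  std::printf("partitionK_range=%d lastK_range=%d\n", partitionK_range, lastK_range);
  return 0;
}

Here every partition but the last gets K = 124 and the last absorbs the remainder (132); both satisfy the asserts because the example K is itself a multiple of partitionK_multiple.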

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -29,7 +29,6 @@
#pragma once
#include "cutlass/fragment.h"
#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
@ -66,6 +65,8 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, half> {
/// Make sure there's an even number of elements in both dimensions.
static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size");
static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size");
static_assert(AccumulatorsPerThread::kH >= 2 && AccumulatorsPerThread::kW >= 2,
"HGEMM expects at least 2x2 accmulator tiles per thread.");
/// Ctor.
CUTLASS_DEVICE ThreadMultiplyAdd() {}
@ -84,7 +85,10 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, half> {
// The output.
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
// The offsets in the output fragment.
int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -38,7 +38,7 @@
#include "cutlass/gemm/gemm_traits.h"
#include "cutlass/gemm/hgemm_global_tile.h"
#include "cutlass/gemm/hgemm_multiply_add.h"
#include "cutlass/gemm/hgemm_swizzle.h"
#include "cutlass/layout/thread/transform.h"
namespace cutlass {
namespace gemm {
@ -107,7 +107,8 @@ struct HgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
template <typename Iterator_>
struct HgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
typedef HgemmSwizzle<Iterator_> Transformer;
typedef typename Iterator_::FragmentShape FragmentShape;
typedef cutlass::layout::thread::Transform<FragmentShape, 2, half, cutlass::MatrixLayout::RowMajor, half, cutlass::MatrixLayout::ColumnMajor > Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -122,7 +123,8 @@ struct HgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
template <typename Iterator_>
struct HgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
typedef HgemmSwizzle<Iterator_> Transformer;
typedef typename Iterator_::FragmentShape FragmentShape;
typedef cutlass::layout::thread::Transform<FragmentShape, 2, half, cutlass::MatrixLayout::RowMajor, half, cutlass::MatrixLayout::ColumnMajor > Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
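
The layout::thread::Transform used above replaces the old swizzle with a per-thread fragment layout change. As a rough mental model only (float instead of half, and a hypothetical 4x2 fragment shape rather than the CUTLASS definitions), the reordering amounts to a small register-resident transpose:

// Minimal sketch of a per-thread fragment layout change from row-major to
// column-major, the kind of reordering the Transform functor above performs.
#include <cstdio>

int main() {
  const int kRows = 4, kColumns = 2;
  float src[kRows * kColumns];          // row-major: src[r * kColumns + c]
  float dst[kRows * kColumns];          // column-major: dst[c * kRows + r]

  for (int i = 0; i < kRows * kColumns; ++i) src[i] = static_cast<float>(i);

  for (int r = 0; r < kRows; ++r) {
    for (int c = 0; c < kColumns; ++c) {
      dst[c * kRows + r] = src[r * kColumns + c];
    }
  }

  for (int i = 0; i < kRows * kColumns; ++i) std::printf("%g ", dst[i]);
  std::printf("\n");
  return 0;
}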

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -28,8 +28,9 @@
*/
#pragma once
#include "cutlass/fragment.h"
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))
#include "cutlass/fragment.h"
#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
@ -44,6 +45,11 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int>
typedef Shape<4, 1, 1> InstructionShape;
/// Shape of the thread-level GEMM (K-by-N-by-M)
typedef ThreadGemmShape_ ThreadGemmShape;
/// Thread-level GEMM (N-by-M) must be a multiple of 32.
static_assert((ThreadGemmShape::kH * ThreadGemmShape::kW) % 32 == 0,
"Thread-level GEMM (N-by-M) must be multiple of 32");
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
typedef ThreadGemmShape AccumulatorsPerThread;
/// The number of threads per warp.
@ -72,19 +78,18 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int>
Accumulators const& c,
Accumulators& d) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)
// The inputs.
int const* a_int = reinterpret_cast<int const*>(&a[0]);
int const* b_int = reinterpret_cast<int const*>(&b[0]);
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
: "=r"(d[j * AccumulatorsPerThread::kW + i])
: "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i]));
}
}
#endif
}
};
@ -92,3 +97,5 @@ struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int>
} // namespace gemm
} // namespace cutlass
#endif // if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 610))
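
For reference, each dp4a.s32.s32 in the inline assembly above computes a four-way int8 dot product added to a 32-bit accumulator. A host sketch of that arithmetic, with the byte packing shown explicitly and arbitrary example values:

// Host reference for one dp4a.s32.s32 operation: d = dot(a.s8x4, b.s8x4) + c.
// Bytes are unpacked in little-endian order, matching the packed registers.
#include <cstdint>
#include <cstdio>

int32_t dp4a_reference(int32_t a_packed, int32_t b_packed, int32_t c) {
  int32_t d = c;
  for (int byte = 0; byte < 4; ++byte) {
    int8_t a = static_cast<int8_t>((a_packed >> (8 * byte)) & 0xff);
    int8_t b = static_cast<int8_t>((b_packed >> (8 * byte)) & 0xff);
    d += static_cast<int32_t>(a) * static_cast<int32_t>(b);
  }
  return d;
}

int main() {
  // a = (1, 2, 3, 4), b = (5, 6, 7, 8), c = 10  ->  10 + 5 + 12 + 21 + 32 = 80
  int32_t a = (4 << 24) | (3 << 16) | (2 << 8) | 1;
  int32_t b = (8 << 24) | (7 << 16) | (6 << 8) | 5;
  std::printf("%d\n", dp4a_reference(a, b, 10));
  return 0;
}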

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -60,6 +60,7 @@ struct IgemmSwizzle {
/// Transform a fragment.
CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
// Expose src/dst as int arrays.
int const* src_int = reinterpret_cast<int const*>(&src[0]);
int* dst_int = reinterpret_cast<int*>(&dst[0]);

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -39,7 +39,7 @@
#include "cutlass/gemm/igemm_epilogue.h"
#include "cutlass/gemm/igemm_global_tile.h"
#include "cutlass/gemm/igemm_multiply_add.h"
#include "cutlass/gemm/igemm_swizzle.h"
#include "cutlass/layout/thread/transform.h"
#include "cutlass/reshape_tile.h"
namespace cutlass {
@ -90,9 +90,10 @@ struct IgemmConfig : public GemmConfig<
/// kResidueSeparate
false,
/// kResidueInPrologue
false,
true,
/// kLaunchBounds
false> {};
false>
{};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -380,7 +381,8 @@ struct IgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
template <typename Iterator_>
struct IgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
typedef IgemmSwizzle<Iterator_> Transformer;
typedef typename Iterator_::FragmentShape FragmentShape;
typedef cutlass::layout::thread::Transform<FragmentShape, 2, int8_t, cutlass::MatrixLayout::RowMajor, int8_t, cutlass::MatrixLayout::ColumnMajor > Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -395,7 +397,8 @@ struct IgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
template <typename Iterator_>
struct IgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
typedef IgemmSwizzle<Iterator_> Transformer;
typedef typename Iterator_::FragmentShape FragmentShape;
typedef cutlass::layout::thread::Transform<FragmentShape, 2, int8_t, cutlass::MatrixLayout::RowMajor, int8_t, cutlass::MatrixLayout::ColumnMajor > Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,6 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

284
cutlass/gemm/mma_epilogue.h Normal file
View File

@ -0,0 +1,284 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
with the computed matrix product.
*/
#pragma once
// clang-format off
#include "cutlass/coord.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename EpilogueTraits_>
struct MMAEpilogue {
/// The traits class.
typedef EpilogueTraits_ Traits;
/// The params.
typedef typename Traits::Params Params;
/// The shared storage.
typedef typename Traits::SharedStorage SharedStorage;
/// Defines a tiling of the EpilogueTile over the entire threadblock GEMM tile
typedef typename Traits::Iterations Iterations;
/// The output tile.
typedef typename Traits::OutputTile OutputTile;
/// Accumulators to store in the epilogue
typedef typename Traits::Accumulators Accumulators;
/// A functor to copy a slice of accumulators for a given epilogue iteration
typedef typename Traits::SelectAccumulators SelectAccumulators;
/// The iterator to load source matrix from global memory.
typedef typename Traits::GlobalLoadStreamC GlobalLoadStreamC;
/// The iterator to store the final GEMM computation to global memory.
typedef typename Traits::GlobalStoreStreamD GlobalStoreStreamD;
/// The stream to store matrix product to shared memory
typedef typename Traits::SharedStoreStreamD SharedStoreStreamD;
/// The stream to load the matrix product from shared memory
typedef typename Traits::SharedLoadStreamD SharedLoadStreamD;
/// The functor in charge of the math.
typedef typename Traits::Functor Functor;
/// The scalar type used by the epilogue functor.
typedef typename Functor::Scalar Scalar;
/// The scalar type of the source accumulator matrix.
typedef typename Traits::ScalarC ScalarC;
/// The scalar type of the destination accumulator matrix.
typedef typename Traits::ScalarD ScalarD;
/// The index type.
typedef typename Traits::Index Index;
/// Functor computing the offset from the threadblock origin per iteration of
/// the epilogue.
typedef typename Traits::GlobalOffset GlobalOffset;
/// The layout of the output matrix in global memory.
typedef typename Traits::GlobalDataLayout GlobalDataLayout;
//
// Data members
//
/// The params.
Params const& params;
/// The shared storage.
SharedStorage& shared_storage;
/// The dimensions of the GEMM.
gemm::GemmCoord problem_size;
/// Epilogue functor
Functor functor;
// Functor to select a set of accumulators
SelectAccumulators select_accumulators;
// Functor to compute the global offset relative to the threadblock for each iteration
// of the epilogue.
GlobalOffset global_offset;
//
// Methods
//
/// Ctor.
CUTLASS_DEVICE MMAEpilogue(
Params const& params_,
SharedStorage& shared_storage_,
Coord<3> const& _problem_size,
SelectAccumulators _select_accumulators = SelectAccumulators(),
GlobalOffset _global_offset = GlobalOffset()
):
params(params_),
shared_storage(shared_storage_),
problem_size(_problem_size),
functor(params_.functor),
select_accumulators(_select_accumulators),
global_offset(_global_offset) {}
/// Execute the epilogue.
CUTLASS_DEVICE void epilogue(
Accumulators& accumulators,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
int batch_id = 0) {
if (functor.source_required()) {
epilogue_with_or_without_beta<true>(accumulators, threadblock_offset, batch_id);
}
else {
epilogue_with_or_without_beta<false>(accumulators, threadblock_offset, batch_id);
}
}
/// Execute the epilogue, with or without loading the source matrix C.
template <bool kSourceRequired>
CUTLASS_DEVICE void epilogue_with_or_without_beta(
Accumulators& accumulators,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
int batch_id = 0) {
/// Global memory mapping function
GlobalDataLayout gmem_map_func;
// Construct shared memory streams
SharedStoreStreamD shared_store_stream(
params.shared_store_stream_d,
shared_storage.reference());
SharedLoadStreamD shared_load_stream(
params.shared_load_stream_d,
shared_storage.reference());
// Map the GEMM problem dimensions into the coordinate system of the output memory
Coord<2> gmem_bounds = gmem_map_func(make_Coord(
problem_size.m(), // GEMM M - rows
problem_size.n())); // GEMM N - columns
Coord<3> gmem_tile_bounds = make_Coord(
problem_size.k(), // GEMM K
gmem_bounds[0], // strided
gmem_bounds[1]); // contiguous
// Iterate over the entire Threadblock tile
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
if (!(h == 0)) {
//continue;
}
// Offset in GEMM coordinates
gemm::GemmCoord offset_in_gemm = threadblock_offset + global_offset(make_Coord(h, w));
Coord<2> offset_in_memory = gmem_map_func(
make_Coord(
offset_in_gemm.m(), // GEMM M - rows
offset_in_gemm.n())); // GEMM N - columns
// Offset of this tile in (K, strided, contiguous) coordinates
Coord<3> global_tile_offset = make_Coord(
offset_in_gemm.k(), // GEMM K
offset_in_memory[0], // strided
offset_in_memory[1]); // contiguous
GlobalLoadStreamC global_load_stream(
params.load_stream_c,
gmem_tile_bounds,
global_tile_offset);
GlobalStoreStreamD global_store_stream(
params.store_stream_d,
gmem_tile_bounds,
global_tile_offset);
// Update the C pointer offset based on batch_id and batch_stride_C.
global_load_stream.iterator.add_pointer_offset(batch_id * params.batch_stride_C);
// Update the D pointer offset based on batch_id and batch_stride_D.
global_store_stream.iterator.add_pointer_offset(batch_id * params.batch_stride_D);
// Load the C matrix into fragment.
if (kSourceRequired) {
global_load_stream.copy();
}
// Make sure we can write to shared memory.
shared_load_fence();
// Store accumulator tile to shared memory
shared_store_stream.copy(
select_accumulators(accumulators, make_Coord(h, w)));
shared_store_stream.commit();
// Make sure the data is in shared memory.
shared_store_fence();
// Load the accumulators back to registers from shared memory.
shared_load_stream.copy();
shared_load_stream.commit();
// Commit the C matrix fragment
if (kSourceRequired) {
global_load_stream.commit();
}
// Apply epilogue functor
if (kSourceRequired) {
functor.evaluate(shared_load_stream.fragment(),
global_load_stream.fragment(),
global_store_stream.fragment());
}
else {
functor.evaluate(
shared_load_stream.fragment(),
global_store_stream.fragment());
}
global_store_stream.copy();
global_store_stream.commit();
}
}
}
/// The memory fence for shared loads.
CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); }
/// The memory fence for shared stores.
CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // gemm
} // namespace cutlass
// clang-format on
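
The Functor invoked at the end of each epilogue iteration is typically a linear scaling of the accumulators with an optional source matrix C. The sketch below is illustrative only: std::array stands in for the register fragments and LinearScaling is a made-up name, not the CUTLASS class, but the two evaluate() overloads mirror the kSourceRequired branches in epilogue_with_or_without_beta().

// Sketch of a linear-scaling epilogue functor: D = alpha * accum + beta * C,
// applied element-wise to a register fragment.
#include <array>
#include <cstddef>
#include <cstdio>

struct LinearScaling {
  float alpha = 1.0f;
  float beta = 0.0f;

  template <std::size_t N>
  void evaluate(std::array<float, N> const& accum,
                std::array<float, N> const& source,
                std::array<float, N>& output) const {
    for (std::size_t i = 0; i < N; ++i) output[i] = alpha * accum[i] + beta * source[i];
  }

  template <std::size_t N>
  void evaluate(std::array<float, N> const& accum, std::array<float, N>& output) const {
    for (std::size_t i = 0; i < N; ++i) output[i] = alpha * accum[i];
  }

  bool source_required() const { return beta != 0.0f; }
};

int main() {
  LinearScaling functor{2.0f, 0.5f};
  std::array<float, 4> accum{1, 2, 3, 4}, c{10, 10, 10, 10}, d{};
  functor.evaluate(accum, c, d);   // d[i] = 2 * accum[i] + 0.5 * c[i]
  for (float x : d) std::printf("%g ", x);
  std::printf("\n");
  return 0;
}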

View File

@ -0,0 +1,360 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements efficient loading of the thread block-level tile from global memory and
storing to shared memory.
*/
#pragma once
// clang-format off
#include "cutlass/convert.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tile_allocation.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
///! Stream adapter for loading threadblock-scoped GEMM tiles and storing to shared memory
template <
/// Identifies multiplicand
GemmOperand::Kind Operand,
/// Layout of source matrix in global memory
MatrixLayout::Kind Layout,
/// Iterator for loading threadblock-scoped tiles
typename LoadIterator_,
/// Transformation functor for transforming fragments
typename Transformer_,
/// Iterator for storing threadblock-scoped tiles to shared memory
typename StoreIterator_,
/// Number of stores before iterator wraps - zero indicates no wrapping
int StageCount>
struct MMAGlobalLoadStream {
//
// Type definitions
//
/// Identifies the operand
static GemmOperand::Kind const kOperand = Operand;
/// The layout.
static MatrixLayout::Kind const kLayout = Layout;
/// The load iterator.
typedef LoadIterator_ LoadIterator;
/// The transformer.
typedef Transformer_ Transformer;
/// The store iterator to write to shared memory.
typedef StoreIterator_ StoreIterator;
/// Number of stages
static int const kStageCount = StageCount;
/// Predicate vector
typedef typename LoadIterator::PredicateVector PredicateVector;
/// The fragment that is copied from shared memory.
typedef typename LoadIterator::Fragment FetchedFragment;
/// The fragment that is obtained after the transformation by the transformer.
typedef typename Transformer::OutputFragment TransformedFragment;
/// Make sure the fragments match.
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
"");
/// The output fragment.
typedef TransformedFragment Fragment;
/// Make sure the transformed fragment is the same as the store fragment.
static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
"");
/// The scalar type of the iterator.
typedef typename LoadIterator::Scalar Scalar;
/// The pointer.
typedef typename LoadIterator::Pointer Pointer;
/// The index.
typedef typename LoadIterator::Index Index;
/// The index.
typedef typename LoadIterator::LongIndex LongIndex;
/// The tile.
typedef typename LoadIterator::Tile Tile;
/// The params.
struct Params {
/// Helper
static int const kElementsPerLdg = LoadIterator::Tile::kC;
//
// Data members
//
/// The load iterator.
typename LoadIterator::Params load_iterator;
/// Stride within a batch of matrix operands
LongIndex batch_stride;
// Offset to residue.
Index offset_to_residue;
// Offset to residue for the last partition
Index offset_to_residue_last_partition;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params(): batch_stride(0), offset_to_residue(0), offset_to_residue_last_partition(0) {}
/// Constructor
CUTLASS_HOST_DEVICE
Params(
TensorRef<half const, 2> const &ref,
Index _offset_to_residue
):
batch_stride(0),
offset_to_residue(_offset_to_residue),
offset_to_residue_last_partition(0),
load_iterator(
TensorRef<half const, 4>(
ref.data(),
make_Coord(ref.stride(0) * kElementsPerLdg, ref.stride(0), kElementsPerLdg, 1)
)
) {}
/// Initializer
CUTLASS_HOST_DEVICE
int initialize(
TensorRef<half const, 2> const &ref,
LongIndex batch_stride_,
Index offset_to_residue_,
Index offset_to_residue_last_partition_) {
batch_stride = batch_stride_;
offset_to_residue = offset_to_residue_;
offset_to_residue_last_partition = offset_to_residue_last_partition_;
return load_iterator.initialize(
TensorRef<half const, 4>(
ref.data(),
make_Coord(static_cast<int>(batch_stride), ref.stride(0), kElementsPerLdg, 1)
)
);
}
CUTLASS_HOST_DEVICE
int initialize(
TensorRef<half const, 2> const &ref,
Index offset_to_residue_) {
offset_to_residue = offset_to_residue_;
return load_iterator.initialize(
TensorRef<half const, 4>(
ref.data(),
make_Coord(ref.stride(0) * kElementsPerLdg, ref.stride(0), kElementsPerLdg, 1)
)
);
}
CUTLASS_HOST_DEVICE int initialize(Index offset_to_residue_) {
offset_to_residue = offset_to_residue_;
return 0;
}
CUTLASS_DEVICE Index get_offset_to_residue() {
if (blockIdx.z == gridDim.z - 1) { //last partition
return offset_to_residue_last_partition;
}
else {
return offset_to_residue;
}
}
};
/// Empty shared storage
struct SharedStorage {};
/// Shared memory allocation for the tile
typedef TileAllocation<
typename StoreIterator::Scalar,
typename ShapeMul<
typename StoreIterator::OperandShape,
Shape<kStageCount, 1, 1, 1>
>::Shape
> ThreadblockTileStorage;
/// ZipTensorRef to threadblock tiles
typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
//
// Data members
//
///! The parameters
Params params;
///! Offset of the threadblock tile in global memory
Coord<3> threadblock_offset;
///! Dimensions of multiplicand bounds
Coord<3> multiplicand_bounds;
///! Iterator to load threadblock tiles from global memory
LoadIterator load_iterator;
///! Predicate vector
PredicateVector predicates;
///! The fragment to fetch from shared memory.
FetchedFragment fetched_fragment;
///! Functor to transform fragments after they have been loaded
Transformer transformer;
///! The fragment to convert the data after it has been fetched from shared memory.
TransformedFragment transformed_fragment;
///! Iterator to store threadblock tiles to shared memory
StoreIterator store_iterator;
///! Counter
int stage_index;
//
// Static member functions
//
/// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory
CUTLASS_HOST_DEVICE
static Coord<3> project_coordinate(Coord<3> const &coord, Index d_offset = 0) {
bool const kKstrided =
gemm::GemmMultiplicandTraits<typename LoadIterator::Tile, kOperand, kLayout>::kKstrided;
Coord<3> tile_coord = gemm::ProjectOperand<kOperand, kKstrided>::project(coord);
return make_Coord(
tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
}
//
// Methods
//
/// Constructor
CUTLASS_DEVICE MMAGlobalLoadStream(Params const &_params,
SharedStorage &shared_storage,
ThreadblockTileRef const &threadblock_tile_ref,
Coord<3> const bounds,
Coord<3> const &block)
: params(_params),
threadblock_offset(project_coordinate(block)),
multiplicand_bounds(project_coordinate(bounds, 1)),
load_iterator(params.load_iterator, threadblock_offset),
transformer(),
store_iterator(threadblock_tile_ref.data()),
stage_index(0) {
load_iterator.initialize_predicates(
predicates.begin(), multiplicand_bounds, threadblock_offset);
}
/// Loads the data from global memory
CUTLASS_DEVICE void copy() {
load_iterator.load_post_increment(fetched_fragment, predicates.begin());
}
/// Transform and commit the data to shared memory
CUTLASS_DEVICE void commit() {
transformer.transform(fetched_fragment, transformed_fragment);
store_iterator.store_post_increment(transformed_fragment);
++stage_index;
if (kStageCount && stage_index == kStageCount) {
store_iterator -= kStageCount;
stage_index = 0;
}
}
/// Computes a predicate mask for loads during final threadblock tile load iteration
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
// That's the residue!
Coord<3> _block_offset = threadblock_offset;
if (kOperand == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor) {
// K-strided
_block_offset =
make_Coord(threadblock_offset[0], multiplicand_bounds[1] - k, threadblock_offset[2]);
} else {
// K-contiguous
_block_offset = make_Coord(threadblock_offset[0],
threadblock_offset[1],
multiplicand_bounds[2] - k / LoadIterator::Tile::kC);
}
load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, _block_offset);
fetched_fragment.clear();
}
/// Move to the residue portion.
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
Index kResidue = k % kTileK;
if (kResidue) {
residue(kResidue);
Index this_offset_residue = params.get_offset_to_residue();
load_iterator.add_pointer_offset(this_offset_residue * load_iterator.stride_advance());
}
}
/// Rollback to the beginning of the first tile
CUTLASS_DEVICE void rollback(void) {
load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, threadblock_offset);
int const kBlock = kOperand == GemmOperand::kA
? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
: (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);
Index this_offset_residue = params.get_offset_to_residue();
load_iterator.add_pointer_offset(-(this_offset_residue + kBlock) *
load_iterator.stride_advance());
}
/// Adds a Coord<3> to the underlying global load iterator
CUTLASS_DEVICE MMAGlobalLoadStream &operator+=(Coord<3> const &offset) {
load_iterator += offset;
return *this;
}
/// Adds an offset based on batch stride
CUTLASS_DEVICE MMAGlobalLoadStream &add_batch_offset(int batch_id) {
load_iterator.add_pointer_offset(batch_id * params.batch_stride);
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // gemm
} // namespace cutlass
// clang-format on
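
commit() above walks the shared-memory store iterator through kStageCount stages and then rewinds it, forming a circular buffer of threadblock tiles. A host sketch of just that wrapping logic, with the stage count and iteration count chosen arbitrarily:

// Host sketch of the circular staging used by MMAGlobalLoadStream::commit():
// after kStageCount stores, the iterator rewinds to the first stage.
#include <cstdio>

int main() {
  const int kStageCount = 2;    // number of shared-memory stages (illustrative)
  int stage_index = 0;          // mirrors the stream's stage counter
  int store_offset = 0;         // position of the store iterator, in stages

  for (int iteration = 0; iteration < 6; ++iteration) {
    std::printf("iteration %d stores into stage %d\n", iteration, store_offset);
    ++store_offset;             // store_iterator.store_post_increment(...)
    ++stage_index;
    if (kStageCount && stage_index == kStageCount) {
      store_offset -= kStageCount;   // store_iterator -= kStageCount
      stage_index = 0;
    }
  }
  return 0;
}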

View File

@ -0,0 +1,201 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
#include "cutlass/util/platform.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Iterators used to load multiplicands from global memory specialized for Volta884 access patterns
//
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Iterator for loading data for congruous access patterns
template <GemmOperand::Kind Operand, typename Tile_, int WarpCount, int WarpDelta>
struct MMAThreadblockCongruousLoad {
/// Identifies multiplicand of GEMM (A or B)
static GemmOperand::Kind const kOperand = Operand;
/// Specifies layout of data in source memory
static MatrixLayout::Kind const kLayout =
(Operand == GemmOperand::kA ? MatrixLayout::kColumnMajor : MatrixLayout::kRowMajor);
/// Shape of thread-block multiplicand
typedef Tile_ Tile;
/// Number of participating warps
static int const kWarpCount = WarpCount;
/// Delta between warp accumulator tiles along the outer dimension
static int const kWarpDelta = WarpDelta;
/// This implementation is specialized for 128b loads
static int const kAccessSize = 8;
/// Projects the threadblock tile
typedef typename gemm::GemmMultiplicandTraits<Tile_, Operand, kLayout>::Shape OperandShape;
/// Reshapes the threadblock tile by access size
typedef typename ReshapeTile<OperandShape, kAccessSize>::Tile VectorizedShape;
/// Shape of the tile covered by each warp per store operation
typedef Shape<1, 4, 8> WarpStoreCoverage;
/// Shape of tile loaded by each warp per load operation
typedef Shape<1, 4, 8> WarpLoadShape;
//
// Load iterator
//
///
typedef Shape<1, WarpLoadShape::kH * kWarpCount, WarpLoadShape::kW> Delta;
typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
/// Rakes warps along contiguous dimensions and strip-mines strided
/// dimension.
typedef Shape<1,
VectorizedShape::kH / WarpStoreCoverage::kH / WarpCount,
VectorizedShape::kW / WarpStoreCoverage::kW,
1>
Iterations;
/// Functor computing starting offset for each thread
struct ThreadOffset {
__device__ Coord<4> operator()() const {
int warp_id = (threadIdx.x >> 5);
int lane_id = (threadIdx.x & 0x1f);
int lane_k = lane_id / WarpLoadShape::kW;
int lane_outer = lane_id % WarpLoadShape::kW;
Coord<4> offset = make_Coord(0, warp_id * WarpLoadShape::kH + lane_k, lane_outer, 0);
return offset;
}
};
/// Source tile traits
typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> LoadTileTraits;
/// Load iterator
typedef TileLoadIterator<LoadTileTraits, half, IteratorAdvance::kH> Iterator;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Iterator for loading data for crosswise access patterns
template <GemmOperand::Kind Operand, typename Tile_, int WarpCount, int WarpDelta>
struct MMAThreadblockCrosswiseLoad {
/// Identifies multiplicand of GEMM (A or B)
static GemmOperand::Kind const kOperand = Operand;
/// Specifies layout of data in source memory
static MatrixLayout::Kind const kLayout =
(Operand == GemmOperand::kA ? MatrixLayout::kRowMajor : MatrixLayout::kColumnMajor);
/// Shape of thread-block multiplicand
typedef Tile_ Tile;
/// Number of participating warps
static int const kWarpCount = WarpCount;
/// Delta between warp accumulator tiles along the outer dimension
static int const kWarpDelta = WarpDelta;
/// This implementation is specialized for 128b loads
static int const kAccessSize = 8;
/// Projects the threadblock tile
typedef typename gemm::GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;
/// Reshapes the threadblock tile by access size
typedef typename ReshapeTile<OperandShape, kAccessSize>::Tile VectorizedShape;
/// Shape of the tile covered by each warp per store operation
typedef Shape<1, 8, 4> WarpStoreCoverage;
/// Shape of tile loaded by each warp per load operation
typedef Shape<1, 8, 4> WarpLoadShape;
//
// Load iterator
//
///
typedef Shape<1, WarpLoadShape::kH, WarpLoadShape::kW> Delta;
typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
/// Rakes warps along contiguous dimensions and strip-mines strided
/// dimension.
typedef Shape<1,
VectorizedShape::kH / WarpStoreCoverage::kH / WarpCount,
VectorizedShape::kW / WarpStoreCoverage::kW,
1>
Iterations;
/// Functor computing starting offset for each thread
struct ThreadOffset {
__device__ Coord<4> operator()() const {
int warp_id = (threadIdx.x >> 5);
int lane_id = (threadIdx.x & 0x1f);
int lane_k = lane_id % WarpLoadShape::kW;
int lane_outer = lane_id / WarpLoadShape::kW;
Coord<4> offset =
make_Coord(0, warp_id * Iterations::kH * WarpLoadShape::kH + lane_outer, lane_k, 0);
return offset;
}
};
/// Source tile traits
typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> LoadTileTraits;
/// Load iterator
typedef TileLoadIterator<LoadTileTraits, half, IteratorAdvance::kW> Iterator;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // gemm
} // namespace cutlass
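
The ThreadOffset functors above split the linear thread index into a warp id and a lane id, then spread the lane across the k and outer dimensions of the warp load shape. A host sketch of the congruous-load mapping, enumerating two warps and printing only lane_outer == 0 to keep the output short:

// Host sketch of the congruous-load ThreadOffset mapping: each thread's linear
// index is split into a warp id and a lane id, and the lane is further split
// across the (k, outer) dimensions of a 4x8 warp load shape.
#include <cstdio>

int main() {
  const int kWarpLoadShapeH = 4;   // WarpLoadShape::kH (k dimension)
  const int kWarpLoadShapeW = 8;   // WarpLoadShape::kW (outer dimension)

  for (int thread = 0; thread < 64; ++thread) {     // two warps, for illustration
    int warp_id = thread >> 5;                      // threadIdx.x / 32
    int lane_id = thread & 0x1f;                    // threadIdx.x % 32
    int lane_k = lane_id / kWarpLoadShapeW;         // 0..3
    int lane_outer = lane_id % kWarpLoadShapeW;     // 0..7
    int h = warp_id * kWarpLoadShapeH + lane_k;     // strided coordinate
    if (lane_outer == 0) {
      std::printf("thread %2d -> offset (0, %d, %d, 0)\n", thread, h, lane_outer);
    }
  }
  return 0;
}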

View File

@ -0,0 +1,155 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements efficient loading of warp-level fragments from shared memory for
matrix multiply-accumulate.
*/
#pragma once
#include "cutlass/convert.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Stream from shared memory to fragments for warp-level matrix multiply-accumulate
template <
/// The load iterator.
typename Iterator_,
/// The transformer to be applied after the data has been copied from shared memory.
typename Transformer_ = Copy<typename Iterator_::Fragment>,
/// Number of increments before iterator wraps - zero indicates no wrapping
int StageCount = 1>
struct MMASharedLoadStream {
/// The load iterator.
typedef Iterator_ Iterator;
/// The transformer.
typedef Transformer_ Transformer;
/// Number of increments before iterator wraps - zero indicates no wrapping
static int const kStageCount = StageCount;
/// The fragment that is copied from shared memory.
typedef typename Iterator::Fragment FetchedFragment;
/// The fragment that is obtained after the transformation by the transformer.
typedef typename Transformer::OutputFragment TransformedFragment;
/// Make sure the fragments match.
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
"");
/// The output fragment.
typedef TransformedFragment Fragment;
/// Element type
typedef typename Iterator::Scalar Scalar;
/// Reference type to a tensor
typedef TensorRef<half, 4> TensorRef;
/// Parameters passed from host
struct Params {};
//
// Data members
//
/// Iterator for loading fragments for warp-level matrix multiply-accumulate
Iterator iterator;
/// Fetched fragment
FetchedFragment fetched[2];
/// The transformer.
Transformer transformer;
/// Transformed fragment
TransformedFragment transformed[2];
/// Counts the number of stages
int stage_index;
//
// Methods
//
/// Ctor.
CUTLASS_DEVICE MMASharedLoadStream() : stage_index(0) {}
/// Ctor.
CUTLASS_DEVICE MMASharedLoadStream(
Params const &_params,
TensorRef const &ref,
Coord<4> warp_offset = make_Coord(0, 0, 0, 0)
):
iterator(ref.data(), warp_offset), stage_index(0) {
}
/// Load the data from shared memory to the fetch fragment.
CUTLASS_DEVICE void copy(int step) {
iterator.load(
fetched[step % 2],
make_Coord(step + stage_index * Iterator::VectorizedShape::kD, 0, 0, 0)
);
}
/// Commit the data.
CUTLASS_DEVICE void commit(int step) {
transformer.transform(fetched[step % 2], transformed[step % 2]);
}
/// Clears the fetched and transformed fragments.
CUTLASS_DEVICE void clear() {
fetched[0].clear();
fetched[1].clear();
transformed[0].clear();
transformed[1].clear();
}
/// Gets the transformed fragment
CUTLASS_DEVICE
TransformedFragment &fragment(int step) { return transformed[step % 2]; }
/// Gets the transformed fragment
CUTLASS_DEVICE
TransformedFragment const &fragment(int step) const { return transformed[step % 2]; }
/// Increment the stage.
CUTLASS_DEVICE void inc_stage() {
++stage_index;
if (kStageCount && stage_index == StageCount) {
stage_index = 0;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // gemm
} // namespace cutlass

View File

@ -1,6 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -73,16 +73,27 @@ struct ThreadMultiplyAdd {
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
if(kLayout_ == MatrixLayout::kColumnMajor) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i];
}
}
}
else {
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < AccumulatorsPerThread::kW; ++i) {
CUTLASS_PRAGMA_UNROLL
for(int j = 0; j < AccumulatorsPerThread::kH; ++j) {
d[i * AccumulatorsPerThread::kH + j] = a[i] * b[j] + c[i * AccumulatorsPerThread::kH + j];
}
}

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -0,0 +1,348 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
with the computed matrix product.
*/
#pragma once
// clang-format off
#include "cutlass/zip_fragment.h"
#include "cutlass/zip_tile_iterator.h"
#include "cutlass/util/complex.h"
#include "cutlass/gemm/volta884_gemm_epilogue_traits.h"
#include "cutlass/gemm/split_complex_linear_scaling.h"
#include "cutlass/util/pair.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Enables treating the accumulator selection as one object
template <typename First_, typename Second_>
struct ZipSelectAccumulators {
/// Underlying selection function
typedef First_ First;
typedef Second_ Second;
/// Accumulators
typedef ZipFragment<
typename First::Accumulators,
typename Second::Accumulators> Accumulators;
/// Fragment
typedef ZipFragment<
typename First::Fragment,
typename Second::Fragment> Fragment;
//
// Data members
//
/// Selects the accumulators for the first part
First first;
/// Selects the accumulators for the second
Second second;
//
// Methods
//
/// Default ctor
CUTLASS_DEVICE
ZipSelectAccumulators() { }
/// Basic constructor
CUTLASS_DEVICE
ZipSelectAccumulators(First const &_first, Second const &_second): first(_first), second(_second) { }
/// Selects accumulators for a given iteration of the epilogue
CUTLASS_DEVICE
Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
return make_ZipFragment(first(accum.first, idx), second(accum.second, idx));
}
};
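
The Zip types above let the epilogue treat the real and imaginary planes of a split-complex GEMM as a single object. A host sketch of the planar (split-complex) multiply-accumulate this representation supports, with made-up values and sizes:

// Host sketch of split-complex accumulation: real and imaginary parts live in
// separate (zipped) fragments, and each multiply-accumulate updates both planes.
#include <cstdio>

int main() {
  const int kN = 4;
  float a_real[kN] = {1, 2, 3, 4}, a_imag[kN] = {0.5f, 0.5f, 0.5f, 0.5f};
  float b_real = 2.0f, b_imag = -1.0f;
  float acc_real[kN] = {}, acc_imag[kN] = {};

  for (int i = 0; i < kN; ++i) {
    // (ar + i*ai) * (br + i*bi) accumulated into planar accumulators.
    acc_real[i] += a_real[i] * b_real - a_imag[i] * b_imag;
    acc_imag[i] += a_real[i] * b_imag + a_imag[i] * b_real;
  }

  for (int i = 0; i < kN; ++i) std::printf("(%g, %g) ", acc_real[i], acc_imag[i]);
  std::printf("\n");
  return 0;
}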
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines epilogue traits for complex-valued mma.sync GEMM
template <
typename GemmConfig_,
typename EpilogueFunctor_ = SplitComplexLinearScaling<typename GemmConfig_::MultiplyAdd::ScalarC>,
typename Index_ = int>
struct Volta884ComplexGemmEpilogueTraits {
/// GEMM configuration
typedef GemmConfig_ GemmConfig;
/// Epilogue functor
typedef EpilogueFunctor_ Functor;
/// Global memory mapping function
typedef MatrixLayout::ColumnMajor GlobalDataLayout;
/// Index type
typedef Index_ Index;
/// Long index used for offsets
typedef long long LongIndex;
/// Defines epilogue traits for real-valued Volta884 GEMM epilogue
typedef typename Volta884GemmEpilogueTraitsHelper<
GemmConfig,
Functor,
typename GemmConfig::MultiplyAdd::RealMultiplyAdd,
Index>::EpilogueTraits RealEpilogueTraits;
/// The output tile.
typedef typename RealEpilogueTraits::OutputTile OutputTile;
/// The warp-level GEMM tile
typedef typename RealEpilogueTraits::WarpGemmTile WarpGemmTile;
/// Tiling of warp accumulator elements
typedef typename RealEpilogueTraits::WarpGemmTile WarpDelta;
/// Multiply-add operation
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
/// The accumulators fragment type.
typedef typename MultiplyAdd::Accumulators Accumulators;
/// Selects a subset of accumulators for a given epilogue iteration
typedef ZipSelectAccumulators<
typename RealEpilogueTraits::SelectAccumulators,
typename RealEpilogueTraits::SelectAccumulators> SelectAccumulators;
/// The iterator to load source matrix from global memory.
typedef cutlass::PredicatedTileLoadStream<
ZipTileIterator<
typename RealEpilogueTraits::GlobalLoadStreamC::Iterator,
typename RealEpilogueTraits::GlobalLoadStreamC::Iterator
>,
typename RealEpilogueTraits::GlobalLoadStreamC::PredicateFunctor,
ZipConvert<
typename RealEpilogueTraits::GlobalLoadStreamC::Transformer,
typename RealEpilogueTraits::GlobalLoadStreamC::Transformer
>
> GlobalLoadStreamC;
/// The iterator to store the final GEMM computation to global memory.
typedef cutlass::PredicatedTileStoreStream<
ZipTileIterator<
typename RealEpilogueTraits::GlobalStoreStreamD::Iterator,
typename RealEpilogueTraits::GlobalStoreStreamD::Iterator
>,
typename RealEpilogueTraits::GlobalStoreStreamD::PredicateFunctor,
ZipConvert<
typename RealEpilogueTraits::GlobalStoreStreamD::Transformer,
typename RealEpilogueTraits::GlobalStoreStreamD::Transformer
>
> GlobalStoreStreamD;
/// The stream to store matrix product to shared memory
typedef cutlass::TileStoreStream<
ZipTileIterator<
typename RealEpilogueTraits::SharedStoreStreamD::Iterator,
typename RealEpilogueTraits::SharedStoreStreamD::Iterator
>,
ZipConvert<
typename RealEpilogueTraits::SharedStoreStreamD::Transformer,
typename RealEpilogueTraits::SharedStoreStreamD::Transformer
>
> SharedStoreStreamD;
/// The stream to load the matrix product from shared memory
typedef cutlass::TileLoadStream<
ZipTileIterator<
typename RealEpilogueTraits::SharedLoadStreamD::Iterator,
typename RealEpilogueTraits::SharedLoadStreamD::Iterator
>,
ZipConvert<
typename RealEpilogueTraits::SharedLoadStreamD::Transformer,
typename RealEpilogueTraits::SharedLoadStreamD::Transformer
>
> SharedLoadStreamD;
/// The scalar type of the source accumulator matrix.
typedef typename RealEpilogueTraits::ScalarC ScalarC;
/// The scalar type of the destination accumulator matrix.
typedef typename RealEpilogueTraits::ScalarD ScalarD;
//
// Dependent types
//
/// Number of iterations needed to cover an entire warp-level tile
typedef typename RealEpilogueTraits::Iterations Iterations;
/// Parameters structure initialized on the host
struct Params {
/// The params for the C iterator.
typename GlobalLoadStreamC::Params load_stream_c;
/// The params for the D global iterator.
typename GlobalStoreStreamD::Params store_stream_d;
/// Epilogue functor params
typename Functor::Params functor;
/// The params for the D shared store iterator.
typename SharedStoreStreamD::Params shared_store_stream_d;
/// The params for the D shared load stream.
typename SharedLoadStreamD::Params shared_load_stream_d;
/// Stride for C
platform::Pair<LongIndex, LongIndex> batch_stride_C;
/// Stride for D
platform::Pair<LongIndex, LongIndex> batch_stride_D;
//
// Methods
//
/// Default constructor
CUTLASS_HOST_DEVICE
Params() {
batch_stride_C.first = 0;
batch_stride_C.second = 0;
batch_stride_D.first = 0;
batch_stride_D.second = 0;
}
/// Setup the params.
CUTLASS_HOST_DEVICE int initialize(
platform::complex<typename Functor::Scalar> alpha,
platform::complex<typename Functor::Scalar> beta,
ScalarC const* real_C,
Index real_ldc,
ScalarC const* imag_C,
Index imag_ldc,
ScalarD* real_D,
Index real_ldd,
ScalarD* imag_D,
Index imag_ldd) {
int result = functor.initialize(alpha, beta);
if (result) {
return result;
}
// Setup the params for the global memory iterator for C.
result = load_stream_c.iterator.first.initialize(
real_C, real_ldc, real_ldc, 1);
if (result) {
return result;
}
result = load_stream_c.iterator.second.initialize(
imag_C, imag_ldc, imag_ldc, 1);
if (result) {
return result;
}
// Setup the params for the global memory iterator for D.
result = store_stream_d.iterator.first.initialize(
real_D, real_ldd, real_ldd, 1);
if (result) {
return result;
}
result = store_stream_d.iterator.second.initialize(
imag_D, imag_ldd, imag_ldd, 1);
if (result) {
return result;
}
return result;
}
/// Setup the params.
CUTLASS_HOST_DEVICE int initialize(
platform::complex<typename Functor::Scalar> alpha,
platform::complex<typename Functor::Scalar> beta,
ScalarC const* real_C,
Index real_ldc,
LongIndex stride_C_real,
ScalarC const* imag_C,
Index imag_ldc,
LongIndex stride_C_imag,
ScalarD* real_D,
Index real_ldd,
LongIndex stride_D_real,
ScalarD* imag_D,
Index imag_ldd,
LongIndex stride_D_imag) {
batch_stride_C.first = stride_C_real;
batch_stride_C.second = stride_C_imag;
batch_stride_D.first = stride_D_real;
batch_stride_D.second = stride_D_imag;
return initialize(alpha, beta, real_C, real_ldc, imag_C, imag_ldc, real_D, real_ldd, imag_D, imag_ldd);
}
};
/// Shared memory buffer used by epilogue
typedef ZipTileAllocation<
typename RealEpilogueTraits::SharedStorage,
typename RealEpilogueTraits::SharedStorage> SharedStorage;
/// Functor computing the offset from the threadblock origin per iteration of
/// the epilogue.
typedef typename RealEpilogueTraits::GlobalOffset GlobalOffset;
};
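// A minimal sketch, guarded out, of how the batched overload of Params::initialize above might be
// called from host code; the pointer and stride names are illustrative and not part of this header.
#if 0
Volta884ComplexGemmEpilogueTraits<MyGemmConfig>::Params epilogue_params;
int error = epilogue_params.initialize(
  alpha, beta,                               // complex-valued scaling factors
  real_C, ldc_real, batch_stride_C_real,     // real plane of the source matrix C
  imag_C, ldc_imag, batch_stride_C_imag,     // imaginary plane of C
  real_D, ldd_real, batch_stride_D_real,     // real plane of the destination matrix D
  imag_D, ldd_imag, batch_stride_D_imag);    // imaginary plane of D
#endif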
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
namespace platform {
/// Scales a pair of batch strides by an integer factor
CUTLASS_HOST_DEVICE
Pair<long long, long long> operator*(int s, Pair<long long, long long> _pair) {
return Pair<long long, long long>(s * _pair.first, s * _pair.second);
}
}
} // namespace cutlass
// clang-format on

View File

@ -0,0 +1,558 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties for complex-valued GEMM targeting Volta's mma.sync
instruction.
At present, it expects a split complex representation in global memory in which the real and
imaginary parts of a complex-valued matrix are stored in disjoint arrays (a structure of arrays).
This is in contrast with an interleaved complex representation, which is an array of structures.
*/
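//
// For illustration, a split complex operand A might be described by two hypothetical pointers
// (these names are not part of this header), assuming column-major half-precision storage:
//
//   half const *real_A;   // M-by-K real parts,      leading dimension real_lda
//   half const *imag_A;   // M-by-K imaginary parts, leading dimension imag_lda
//   // Element A(i, j) is the pair (real_A[i + j * real_lda], imag_A[i + j * imag_lda]).
//
// An interleaved representation would instead store complex-valued elements contiguously as an
// array of structures, which is not what these traits expect.
//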
#pragma once
// clang-format off
#include "cutlass/gemm/clear_accumulators.h"
#include "cutlass/gemm/gemm_config.h"
#include "cutlass/gemm/gemm_stream_pair.h"
#include "cutlass/gemm/threadblock_swizzle.h"
#include "cutlass/gemm/linear_scaling.h"
#include "cutlass/kernel_launch.h"
#include "cutlass/tensor_ref_collection.h"
#include "cutlass/gemm/gemm_desc.h"
#include "cutlass/gemm/volta884_multiplicand.h"
#include "cutlass/gemm/mma_shared_stream.h"
#include "cutlass/gemm/volta884_gemm_traits.h"
#include "cutlass/gemm/volta884_complex_multiply_add.h"
#include "cutlass/gemm/volta884_complex_global_stream.h"
#include "cutlass/gemm/volta884_complex_shared_stream.h"
#include "cutlass/gemm/volta884_complex_gemm_epilogue_traits.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines configuration for Volta884 GEMM
template <
/// The layout for A.
MatrixLayout::Kind LayoutA,
/// Indicates matrix transform on multiplicand A
MatrixTransform::Kind TransformA,
/// The layout for B.
MatrixLayout::Kind LayoutB,
/// Indicates matrix transform on multiplicand B
MatrixTransform::Kind TransformB,
/// The tile size for the GEMM KxNxM.
typename OutputTile_,
/// Tile size for warp-level GEMM (K-by-N-by-M)
typename WarpGemmShape_,
/// The accumulator type.
typename Accumulator_,
/// The source matrix type.
typename ScalarC_,
/// The destination matrix type
typename ScalarD_,
/// Number of stages in shared memory
int StageCount,
/// Enables or disables launch bounds
bool LaunchBounds>
struct Volta884ComplexGemmConfig : public GemmConfig<
/// The scalar type for A.
half,
/// The scalar type for B.
half,
/// The scalar type for C.
ScalarC_,
/// The scalar type for D.
ScalarD_,
/// The threadblock tile size
OutputTile_,
/// The functor to do the math in the main loop.
Volta884ComplexMultiplyAdd<WarpGemmShape_,
LayoutA,
TransformA,
half,
LayoutB,
TransformB,
half,
Accumulator_>,
/// The number of scalars per LDG for A.
8,
/// The number of scalars per STS for A.
8,
/// The number of scalars per LDS for A.
8,
/// The number of scalars per LDG for B.
8,
/// The number of scalars per STS for B.
8,
/// The number of scalars per LDS for B.
8,
/// The number of scalars per LDG for C and STG for D.
16 / int(sizeof(ScalarD_)),
/// The number of scalars per STS for D.
16 / int(sizeof(ScalarD_)),
/// The number of scalars per LDS for D.
16 / int(sizeof(ScalarD_)),
/// The number of stages in shared memory.
StageCount,
/// If true, separate mainloop is instantiated
true,
/// If true, compute residue in prolog
false,
/// Enables or disables launch bounds
LaunchBounds> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines components of Volta884 GEMM
template <
/// The layout for A.
MatrixLayout::Kind LayoutA,
/// Indicates matrix transform on multiplicand A
MatrixTransform::Kind TransformA,
/// The layout for B.
MatrixLayout::Kind LayoutB,
/// Indicates matrix transform on multiplicand B
MatrixTransform::Kind TransformB,
/// The tile size for the GEMM KxNxM.
typename OutputTile_,
/// Tile size for warp-level GEMM (K-by-N-by-M)
typename WarpGemmShape_,
/// The accumulator type.
typename Accumulator_,
/// The input matrix type.
typename ScalarC_,
/// The output matrix type.
typename ScalarD_,
/// Number of buffers in shared memory to use
int StageCount,
/// The functor to do the math in the epilogue.
typename EpilogueFunctor_ = SplitComplexLinearScaling<Accumulator_>,
/// Enables or disables launch bounds
bool LaunchBounds = false
>
struct Volta884ComplexGemmTraits {
/// Self-referential alias to these traits
typedef Volta884ComplexGemmTraits<
LayoutA,
TransformA,
LayoutB,
TransformB,
OutputTile_,
WarpGemmShape_,
Accumulator_,
ScalarC_,
ScalarD_,
StageCount,
EpilogueFunctor_,
LaunchBounds> This;
/// The actual device-side GEMM
typedef GemmMainloop<This> KernelClass;
/// Layout of multiplicand A matrix
static MatrixLayout::Kind const kLayoutA = LayoutA;
/// If true, A operand is conjugated
static MatrixTransform::Kind const kTransformA = TransformA;
/// Layout of multiplicand B matrix
static MatrixLayout::Kind const kLayoutB = LayoutB;
/// If true, B operand is conjugated
static MatrixTransform::Kind const kTransformB = TransformB;
/// Dimensions of threadblock tile (concept Shape)
typedef OutputTile_ OutputTile;
/// Shape of warp-level accumulators
typedef WarpGemmShape_ WarpGemmShape;
/// Multiplicand A scalar type
typedef half ScalarA;
/// Multiplicand B scalar type
typedef half ScalarB;
/// Data type of internal accumulator
typedef Accumulator_ Accumulator;
/// Data type of input accumulator matrix operand
typedef ScalarC_ ScalarC;
/// Data type of output accumulator matrix operand
typedef ScalarD_ ScalarD;
/// Shape of individual mma.sync instruction
typedef Shape<4, 16, 16> InstructionShape;
/// Tile size for an individual warp-level multiply-add
typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;
/// Defines properties about GEMM needed by host code
typedef Volta884ComplexGemmConfig<
kLayoutA,
kTransformA,
kLayoutB,
kTransformB,
OutputTile,
WarpGemmShape,
Accumulator,
ScalarC,
ScalarD,
StageCount,
LaunchBounds>
GemmConfig;
//
// Derived types
//
/// Index type
typedef int Index;
/// Long index type
typedef long long LongIndex;
/// Partitioning of threadblock into warps
typedef typename ShapeDiv<OutputTile, WarpGemmShape>::Shape WarpDelta;
/// Number of warps per threadblock
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
/// Defines iterators for A matrix
typedef Volta884Multiplicand<GemmOperand::kA, kLayoutA, OutputTile, WarpTile, kWarpCount, WarpDelta>
MultiplicandA;
/// Defines iterators for B matrix
typedef Volta884Multiplicand<GemmOperand::kB, kLayoutB, OutputTile, WarpTile, kWarpCount, WarpDelta>
MultiplicandB;
//
// GemmTraits mandatory type definitions
//
/// Maps hardware threadblocks to logical partitions of the GEMM
typedef IdentityBlockSwizzle BlockSwizzle;
/// Clears accumulators
typedef ClearAccumulators<ScalarC> ClearAccumulators;
/// Loads multiplicands from global memory
typedef GlobalLoadStreamPair<
Volta884ComplexGlobalLoadStream<GemmOperand::kA,
kLayoutA,
typename MultiplicandA::LoadIterator,
Copy<typename MultiplicandA::LoadIterator::Fragment>,
typename MultiplicandA::StoreIterator,
StageCount>,
Volta884ComplexGlobalLoadStream<GemmOperand::kB,
kLayoutB,
typename MultiplicandB::LoadIterator,
Copy<typename MultiplicandB::LoadIterator::Fragment>,
typename MultiplicandB::StoreIterator,
StageCount>,
GemmConfig::kResidueInProlog >
GlobalLoadStream;
/// Memory needed to store the threadblock-scoped GEMM tile
typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage;
/// Shared memory storage for mainloop phase
union MainLoopStorage {
/// Stores the threadblock tile
ThreadblockTileStorage threadblock_tile;
/// Storage for GEMM global stream
typename GlobalLoadStream::SharedStorage global_to_shared_stream;
};
/// Loads multiplicands from shared memory
typedef SharedStreamPair<
Volta884ComplexSharedLoadStream<typename MultiplicandA::WarpLoadIterator,
Copy<typename MultiplicandA::WarpLoadIterator::Fragment>,
StageCount>,
Volta884ComplexSharedLoadStream<typename MultiplicandB::WarpLoadIterator,
Copy<typename MultiplicandB::WarpLoadIterator::Fragment>,
StageCount> >
SharedStream;
// Multiply-add object specialized for Volta mma.sync
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
#if 0
/// Naive epilogue for updating the output matrix
typedef Volta884ComplexNaiveEpilogue<ScalarC,
typename MultiplicandA::WarpDelta,
typename MultiplyAdd::Iterations>
Epilogue;
#else
/// Efficient epilogue
typedef MMAEpilogue<
Volta884ComplexGemmEpilogueTraits<GemmConfig, EpilogueFunctor_>
> Epilogue;
#endif
/// Tensor reference to A multiplicand
typedef ZipTensorRef<
TensorRef<ScalarA, 2>,
TensorRef<ScalarA, 2>
> TensorRefA;
/// Tensor reference to B multiplicand
typedef ZipTensorRef<
TensorRef<ScalarB, 2>,
TensorRef<ScalarB, 2>
> TensorRefB;
/// Tensor reference to C multiplicand
typedef ZipTensorRef<
TensorRef<ScalarC, 2>,
TensorRef<ScalarC, 2>
> TensorRefC;
/// Tensor reference to D multiplicand
typedef ZipTensorRef<
TensorRef<ScalarD, 2>,
TensorRef<ScalarD, 2>
> TensorRefD;
/// GEMM problem descriptor
typedef GemmDesc<
TensorRefA,
TensorRefB,
TensorRefC,
TensorRefD,
float
> GemmDesc;
/// Parameters structure
struct Params : public KernelLaunchConfiguration {
/// The dimensions of the GEMM.
GemmCoord problem_size;
/// PartitionK_range
int partitionK_range;
/// The params for the global load stream
typename GlobalLoadStream::Params global_to_shared_stream;
/// The params for the shared load stream
typename SharedStream::Params shared_stream;
/// The params for the epilogue.
typename Epilogue::Params epilogue;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() {}
/// Initialize the Params struct
CUTLASS_HOST_DEVICE int initialize(
Index m,
Index n,
Index k,
platform::complex<typename Epilogue::Scalar> alpha,
ScalarA const* real_A,
Index real_lda,
ScalarA const* imag_A,
Index imag_lda,
ScalarB const* real_B,
Index real_ldb,
ScalarB const* imag_B,
Index imag_ldb,
platform::complex<typename Epilogue::Scalar> beta,
ScalarC const* real_C,
Index real_ldc,
ScalarC const* imag_C,
Index imag_ldc,
ScalarD* real_D,
Index real_ldd,
ScalarD* imag_D,
Index imag_ldd) {
problem_size = make_Coord(k, n, m, 1);
partitionK_range = problem_size.k();
// Compute grid dimensions
BlockSwizzle block_swizzle;
this->block = dim3(GemmConfig::kThreads);
this->grid = block_swizzle.get_grid_layout(
problem_size,
make_Coord_from_shape<OutputTile>());
// Initialize global load streams
global_to_shared_stream.stream_a.initialize(
make_ZipTensorRef(
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_A, real_lda), 0),
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_A, imag_lda), 0)
),
0
);
global_to_shared_stream.stream_b.initialize(
make_ZipTensorRef(
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_B, real_ldb), 0),
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_B, imag_ldb), 0)
),
0
);
return epilogue.initialize(
alpha,
beta,
real_C,
real_ldc,
imag_C,
imag_ldc,
real_D,
real_ldd,
imag_D,
imag_ldd
);
}
/// Initialize the Params struct
CUTLASS_HOST_DEVICE int initialize(
Index m,
Index n,
Index k,
platform::complex<typename Epilogue::Scalar> alpha,
ScalarA const* real_A,
Index real_lda,
LongIndex batch_stride_A_real,
ScalarA const* imag_A,
Index imag_lda,
LongIndex batch_stride_A_imag,
ScalarB const* real_B,
Index real_ldb,
LongIndex batch_stride_B_real,
ScalarB const* imag_B,
Index imag_ldb,
LongIndex batch_stride_B_imag,
platform::complex<typename Epilogue::Scalar> beta,
ScalarC const* real_C,
Index real_ldc,
LongIndex batch_stride_C_real,
ScalarC const* imag_C,
Index imag_ldc,
LongIndex batch_stride_C_imag,
ScalarD* real_D,
Index real_ldd,
LongIndex batch_stride_D_real,
ScalarD* imag_D,
Index imag_ldd,
LongIndex batch_stride_D_imag,
int batch_count) {
problem_size = make_Coord(k, n, m, batch_count);
partitionK_range = problem_size.k();
// Compute grid dimensions
BlockSwizzle block_swizzle;
this->block = dim3(GemmConfig::kThreads);
this->grid = block_swizzle.get_grid_layout(
problem_size,
make_Coord_from_shape<OutputTile>());
// Initialize global load streams
global_to_shared_stream.stream_a.initialize(
make_ZipTensorRef(
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_A, real_lda), batch_stride_A_real),
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_A, imag_lda), batch_stride_A_imag)
),
0
);
global_to_shared_stream.stream_b.initialize(
make_ZipTensorRef(
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(real_B, real_ldb), batch_stride_B_real),
TensorRefBatchStrided<half const, 2>(TensorRef<half const, 2>(imag_B, imag_ldb), batch_stride_B_imag)
),
0
);
return epilogue.initialize(
alpha,
beta,
real_C,
real_ldc,
batch_stride_C_real,
imag_C,
imag_ldc,
batch_stride_C_imag,
real_D,
real_ldd,
batch_stride_D_real,
imag_D,
imag_ldd,
batch_stride_D_imag
);
}
};
/// Shared memory storage
union SharedStorage {
/// Storage required during mainloop phase
MainLoopStorage main_loop;
/// Shared storage needed for epilogue
typename Epilogue::SharedStorage epilogue;
};
/// The memory fence for shared loads.
static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
if (StageCount < 2) {
__syncthreads();
}
}
/// The memory fence for shared stores.
static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { __syncthreads(); }
};
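// A minimal sketch, guarded out, of configuring these traits from host code. The tile shapes,
// scalar types, and pointer names are illustrative assumptions; kernel launch is handled by the
// library and omitted here.
#if 0
typedef cutlass::gemm::Volta884ComplexGemmTraits<
  cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixTransform::kNone,   // A: layout, transform
  cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixTransform::kNone,   // B: layout, transform
  cutlass::Shape<32, 128, 128>,   // threadblock tile (K-by-N-by-M)
  cutlass::Shape<32, 64, 64>,     // warp-level GEMM tile (K-by-N-by-M)
  float, float, float,            // accumulator, source C, and destination D types
  2                               // stages in shared memory
> Traits;

Traits::Params params;
int error = params.initialize(
  m, n, k, alpha,
  real_A, lda, imag_A, lda,
  real_B, ldb, imag_B, ldb,
  beta,
  real_C, ldc, imag_C, ldc,
  real_D, ldd, imag_D, ldd);
#endif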
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
// clang-format on

View File

@ -0,0 +1,315 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements efficient loading of the thread block-level tile from global memory and
storing to shared memory.
*/
#pragma once
// clang-format off
#include "cutlass/convert.h"
#include "cutlass/zip_tile_iterator.h"
#include "cutlass/zip_tensor_ref.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/util/pair.h"
#include "cutlass/gemm/mma_global_stream.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
///! Stream adapter for loading threadblock-scoped GEMM tiles and storing to shared memory
template <
/// Identifies multiplicand
GemmOperand::Kind Operand,
/// Layout of source matrix in global memory
MatrixLayout::Kind Layout,
/// Iterator for loading threadblock-scoped tiles
typename LoadIterator_,
/// Transformation functor for transforming fragments
typename Transformer_,
/// Iterator for storing threadblock-scoped tiles to shared memory
typename StoreIterator_,
/// Number of stores before iterator wraps - zero indicates no wrapping
int StageCount>
struct Volta884ComplexGlobalLoadStream {
//
// Type definitions
//
/// Identifies the operand
static GemmOperand::Kind const kOperand = Operand;
/// The layout.
static MatrixLayout::Kind const kLayout = Layout;
/// Load-store stream for real-valued matrices
typedef MMAGlobalLoadStream<Operand, Layout, LoadIterator_, Transformer_, StoreIterator_, StageCount> RealLoadStoreStream;
/// Loads a pair of real-valued fragments
typedef ZipTileIterator<LoadIterator_, LoadIterator_> LoadIterator;
/// Zips a pair of transformers
typedef ZipConvert<Transformer_, Transformer_> Transformer;
/// Stores a pair of real-valued fragments
typedef ZipTileIterator<StoreIterator_, StoreIterator_> StoreIterator;
/// Number of stages
static int const kStageCount = StageCount;
/// Predicate vector
typedef typename RealLoadStoreStream::PredicateVector PredicateVector;
/// The fragment that is fetched from global memory.
typedef typename LoadIterator::Fragment FetchedFragment;
/// The fragment that is obtained after the transformation by the transformer.
typedef typename Transformer::OutputFragment TransformedFragment;
/// Make sure the fragments match.
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
"");
/// The output fragment.
typedef TransformedFragment Fragment;
/// Make sure the transformed fragment is the same as the store fragment.
static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
"");
/// Index type
typedef typename RealLoadStoreStream::Index Index;
/// Long index type
typedef typename RealLoadStoreStream::LongIndex LongIndex;
/// The params.
struct Params {
//
// Type definitions
//
/// Matrix reference
typedef ZipTensorRef<
TensorRefBatchStrided<half const, 2>,
TensorRefBatchStrided<half const, 2> > SourceTensorRef;
/// Helper
static int const kElementsPerLdg = LoadIterator::First::Tile::kC;
//
// Data members
//
/// Source tensor reference
platform::Pair<LongIndex, LongIndex> batch_stride;
// The load iterator.
typename LoadIterator::Params load_iterator;
// Offset to residue.
Index offset_to_residue;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() {}
///
CUTLASS_HOST_DEVICE
Params(SourceTensorRef const &ref, Index _offset_to_residue) {
initialize(ref, _offset_to_residue);
}
CUTLASS_HOST_DEVICE
int initialize(SourceTensorRef const &ref, Index _offset_to_residue) {
batch_stride.first = ref.first.tensor_stride;
batch_stride.second = ref.second.tensor_stride;
offset_to_residue = _offset_to_residue;
load_iterator.first.initialize(
TensorRef<half const, 4>(
ref.first.at().data(),
make_Coord(ref.first.at().stride(0) * kElementsPerLdg, ref.first.at().stride(0), kElementsPerLdg)
)
);
load_iterator.second.initialize(
TensorRef<half const, 4>(
ref.second.at().data(),
make_Coord(ref.second.at().stride(0) * kElementsPerLdg, ref.second.at().stride(0), kElementsPerLdg)
)
);
return 0;
}
};
/// Empty shared storage
struct SharedStorage {};
/// Shared memory allocation for the tile
typedef TileAllocation<
typename RealLoadStoreStream::StoreIterator::Scalar,
typename ShapeMul<
typename RealLoadStoreStream::StoreIterator::OperandShape,
Shape<kStageCount, 1, 1, 1>
>::Shape
> RealThreadblockTileStorage;
/// Threadblock tile allocation
typedef ZipTileAllocation<
RealThreadblockTileStorage,
RealThreadblockTileStorage
> ThreadblockTileStorage;
/// Reference to ThreadblockTileStorage
typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
//
// Data members
//
///! The parameters
Params params;
///! Offset of the threadblock tile within global memory
Coord<3> threadblock_offset;
///! Multiplicand bounds
Coord<3> multiplicand_bounds;
///! Iterator to load threadblock tiles from global memory
LoadIterator load_iterator;
///! Predicate vector
PredicateVector predicates;
///! The fragment fetched from global memory.
FetchedFragment fetched_fragment;
///! Functor to transform fragments after they have been loaded
Transformer transformer;
///! The transformed fragment produced after data has been fetched from global memory.
TransformedFragment transformed_fragment;
///! Iterator to store threadblock tiles to shared memory
StoreIterator store_iterator;
///! Counter
int stage_index;
//
// Methods
//
/// Constructor
CUTLASS_DEVICE Volta884ComplexGlobalLoadStream(Params const &_params,
SharedStorage &shared_storage,
ThreadblockTileRef const &threadblock_tile_ref,
Coord<3> const bounds,
Coord<3> const &block)
: params(_params),
threadblock_offset(RealLoadStoreStream::project_coordinate(block)),
multiplicand_bounds(RealLoadStoreStream::project_coordinate(bounds, 1)),
load_iterator(params.load_iterator, threadblock_offset),
transformer(),
store_iterator(threadblock_tile_ref),
stage_index(0) {
// initialize predicates used to guard loads
load_iterator.initialize_predicates(
predicates.begin(), multiplicand_bounds, threadblock_offset);
}
/// Loads the data from global memory
CUTLASS_DEVICE void copy() {
load_iterator.load_post_increment(fetched_fragment, predicates.begin());
}
/// Transform and commit the data to shared memory
CUTLASS_DEVICE void commit() {
transformer.transform(fetched_fragment, transformed_fragment);
store_iterator.store_post_increment(transformed_fragment);
++stage_index;
if (kStageCount && stage_index == kStageCount) {
store_iterator -= kStageCount;
stage_index = 0;
}
}
/// Computes a predicate mask for loads during final threadblock tile load iteration
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
// Compute the threadblock offset of the residue tile along the K dimension
Coord<3> _block_offset = threadblock_offset;
if (kOperand == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor) {
// K-strided
_block_offset =
make_Coord(threadblock_offset[0], multiplicand_bounds[1] - k, threadblock_offset[2]);
} else {
// K-contiguous
_block_offset = make_Coord(threadblock_offset[0],
threadblock_offset[1],
multiplicand_bounds[2] - k / LoadIterator::First::Tile::kC);
}
load_iterator.initialize_predicates(predicates.begin(), multiplicand_bounds, _block_offset);
fetched_fragment.clear();
}
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {}
CUTLASS_DEVICE void rollback() {}
/// Adds a Coord<3> to the underlying global load iterator
CUTLASS_DEVICE Volta884ComplexGlobalLoadStream &operator+=(Coord<3> const &offset) {
load_iterator += offset;
return *this;
}
/// Adds an offset based on batch stride
CUTLASS_DEVICE Volta884ComplexGlobalLoadStream &add_batch_offset(int batch_id) {
load_iterator.first.add_pointer_offset(params.batch_stride.first * batch_id);
load_iterator.second.add_pointer_offset(params.batch_stride.second * batch_id);
return *this;
}
};
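// The stream above is driven by the mainloop in roughly the following order (a simplified
// outline; the exact sequencing is owned by the GemmMainloop kernel):
//
//   stream.copy();                      // issue predicated loads for the real and imaginary tiles
//   stream.commit();                    // transform the fetched fragments and store them to shared memory
//   stream.residue(k);                  // before the last tile, recompute predicates for the partial
//                                       // tile at the K boundary
//   stream.add_batch_offset(batch_id);  // for batched GEMMs, advance both planes by their batch strides
//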
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
// clang-format on

View File

@ -0,0 +1,319 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements warp-level multiply-accumulate operations using Volta's mma.sync instruction
for complex-valued data types.
*/
#pragma once
#include "cutlass/util/complex.h"
#include "cutlass/zip_fragment.h"
#include "cutlass/gemm/volta884_multiply_add.h"
#include "cutlass/zip_fragment.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Shape of a warp-level GEMM (K-by-N-by-M)
typename WarpGemmShape_,
/// Layout of multiplicand A
MatrixLayout::Kind LayoutA,
/// Indicates matrix transform on multiplicand A
MatrixTransform::Kind TransformA,
/// Data type of multiplicand A
typename ScalarA_,
/// Layout of multiplicand B
MatrixLayout::Kind LayoutB,
/// Indicates matrix transform on multiplicand B
MatrixTransform::Kind TransformB,
/// Data type of multiplicand B
typename ScalarB_,
/// Data type of accumulators
typename ScalarC_,
/// If true, A operand is conjugated
bool ConjugateA = false,
/// If true, B operand is conjugated
bool ConjugateB = false,
/// If true, infinite results are saturated to +-MAX_FLOAT
bool SatFinite = false>
struct Volta884ComplexMultiplyAdd {
//
// Constant and type definitions
//
/// Shape of a warp-level GEMM (K-by-N-by-M)
typedef WarpGemmShape_ WarpGemmShape;
/// Shape of a warp-level GEMM (K-by-N-by-M)
typedef WarpGemmShape_ AccumulatorsPerWarp;
/// Most of the Volta884 code assumes interleaved 32x32 tiles
typedef Shape<4, 32, 32> InterleavedTileShape;
/// Shape of an individual warp-wide mma.sync instruction
typedef Shape<4, 16, 16> InstructionShape;
/// Shape of a warp-level matrix multiply operation
typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;
/// Verify WarpTile is a multiple of fundamental 32x32 interleaved tile
static_assert(!(WarpTile::kH % InterleavedTileShape::kH) &&
!(WarpTile::kW % InterleavedTileShape::kW) && WarpTile::kD == 4,
"WarpTile must be a multiple of InterleavedTileShape.");
/// Layout of A multiplicand
static MatrixLayout::Kind const kLayoutA = LayoutA;
/// Indicates matrix transform on multiplicand A
static MatrixTransform::Kind const kTransformA = TransformA;
/// Layout of B multiplicand
static MatrixLayout::Kind const kLayoutB = LayoutB;
/// Indicates matrix transform on multiplicand B
static MatrixTransform::Kind const kTransformB = TransformB;
/// The type for A.
typedef ScalarA_ ScalarA;
/// The type for B.
typedef ScalarB_ ScalarB;
/// The type for C and D.
typedef ScalarC_ ScalarC;
/// If true, infinite results are saturated to +-MAX_FLOAT
static bool const kSatFinite = SatFinite;
/// Hard-coded compute type supported on Volta
static arch::ComputeType::Kind const kComputeType = arch::ComputeType::kDefault;
/// Underlying matrix multiply-add operator
typedef Volta884MultiplyAdd<WarpGemmShape,
kLayoutA,
ScalarA,
kLayoutB,
ScalarB,
ScalarC>
RealMultiplyAdd;
/// Fragment definition for A multiplicand
typedef ZipFragment<typename RealMultiplyAdd::FragmentA, typename RealMultiplyAdd::FragmentA>
FragmentA;
/// Fragment definition for B multiplicand
typedef ZipFragment<typename RealMultiplyAdd::FragmentB, typename RealMultiplyAdd::FragmentB>
FragmentB;
/// Fragment definition for accumulators
typedef ZipFragment<typename RealMultiplyAdd::Accumulators,
typename RealMultiplyAdd::Accumulators>
Accumulators;
/// Number of mma.sync operations performed. See Volta884MultiplyAdd::Iterations for details.
typedef typename RealMultiplyAdd::Iterations Iterations;
//
// Methods
//
/// Ctor.
CUTLASS_DEVICE Volta884ComplexMultiplyAdd() {}
/// Multiply : d = a*b.
CUTLASS_DEVICE void multiply_add(FragmentA const& A,
FragmentB const& B,
Accumulators const& C,
Accumulators& D) {
RealMultiplyAdd op;
// Complex-valued multiply-add realized as four real-valued mma.sync passes
op.multiply_add(A.first, B.first, C.first, D.first);
op.multiply_add(A.first, B.second, C.second, D.second, kTransformB == MatrixTransform::kConjugate);
op.multiply_add(A.second, B.first, C.second, D.second, kTransformA == MatrixTransform::kConjugate);
op.multiply_add(A.second, B.second, C.first, D.first,
!((kTransformA == MatrixTransform::kConjugate) ^ (kTransformB == MatrixTransform::kConjugate)));
}
};
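// Written out for the non-conjugated case, the four real-valued passes in multiply_add() compute
//
//   D.first  (real) = C.first  + A.first * B.first  - A.second * B.second
//   D.second (imag) = C.second + A.first * B.second + A.second * B.first
//
// where .first / .second hold the real / imaginary fragments. When A and/or B is conjugated
// (kTransformA / kTransformB == MatrixTransform::kConjugate), the sign of each product term that
// involves that operand's imaginary fragment flips; the trailing boolean passed to
// RealMultiplyAdd::multiply_add is what requests the negated accumulation of that term.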
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Complex-valued epilogue
template <typename Accumulator, typename WarpDelta, typename Iterations>
struct Volta884ComplexNaiveEpilogue {
/// Accumulator data type
typedef Accumulator ScalarC;
/// Output accumulator type
typedef Accumulator ScalarD;
/// BLAS Scalar type
typedef Accumulator Scalar;
/// Real-valued epilogue
typedef Volta884NaiveEpilogue<Accumulator, WarpDelta, Iterations> RealEpilogue;
/// Params object
struct Params {
/// Parameters for the real-valued part
typename RealEpilogue::Params real;
/// Parameters for the imaginary-valued part
typename RealEpilogue::Params imag;
//
// Methods
//
/// Default constructor
CUTLASS_HOST_DEVICE Params() {}
/// Constructs from params object
CUTLASS_HOST_DEVICE Params(typename RealEpilogue::Params const& _real,
typename RealEpilogue::Params const& _imag)
: real(_real), imag(_imag) {}
/// Construct from pointers
CUTLASS_HOST_DEVICE Params(ScalarC* _real, int _ldr, ScalarC* _imag, int _ldi)
: real(_real, _ldr), imag(_imag, _ldi) {}
/// Construct from pointers
CUTLASS_HOST_DEVICE Params(
platform::complex<Scalar> const &alpha,
platform::complex<Scalar> const &beta,
ScalarC const *real_C,
int real_ldc,
ScalarC const *imag_C,
int imag_ldc,
ScalarD *real_D,
int real_ldd,
ScalarD *imag_D,
int imag_ldd
):
real(real_D, real_ldd, alpha.real(), beta.real()),
imag(imag_D, imag_ldd, alpha.real(), beta.real()) { }
/// Initializer method
CUTLASS_HOST_DEVICE
int initialize(
platform::complex<Scalar> const &alpha,
platform::complex<Scalar> const &beta,
ScalarC const *real_C,
int real_ldc,
ScalarC const *imag_C,
int imag_ldc,
ScalarD *real_D,
int real_ldd,
ScalarD *imag_D,
int imag_ldd
) {
real = typename RealEpilogue::Params(real_D, real_ldd, alpha.real(), beta.real());
imag = typename RealEpilogue::Params(imag_D, imag_ldd, alpha.real(), beta.real());
return 0;
}
};
/// Shared storage
struct SharedStorage {};
/// Accumulator fragment definition
typedef ZipFragment<
typename RealEpilogue::Accumulators,
typename RealEpilogue::Accumulators> Accumulators;
//
// Data members
//
/// Epilogue for real part
RealEpilogue real;
/// Epilogue for imaginary part
RealEpilogue imag;
//
// Methods
//
/// Constructs a complex-valued epilogue
CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(
Params const& _params, Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
: real(_params.real, _problem_size), imag(_params.imag, _problem_size) {}
/// Constructs a complex-valued epilogue
CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(ScalarC* _real,
int _ldr,
ScalarC* _imag,
int _ldi,
Coord<3> const& _problem_size = make_Coord(1024,
1024,
1024))
: real(_real, _ldr, _problem_size), imag(_imag, _ldi, _problem_size) {}
/// Constructs a complex-valued epilogue
CUTLASS_DEVICE Volta884ComplexNaiveEpilogue(Params const& _params,
SharedStorage& shared_storage,
Coord<3> const& _problem_size = make_Coord(1024,
1024,
1024))
: real(_params.real, _problem_size), imag(_params.imag, _problem_size) {}
/// Sets accumulators to zero
CUTLASS_DEVICE void clear(Accumulators& C) {
C.first.clear();
C.second.clear();
}
/// Naive load operation for debugging
CUTLASS_DEVICE void load(Accumulators& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
real.load(C.first, threadblock_offset);
imag.load(C.second, threadblock_offset);
}
/// Naive store operation for debugging
CUTLASS_DEVICE void store(Accumulators const& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
real.store(C.first, threadblock_offset);
imag.store(C.second, threadblock_offset);
}
/// CUTLASS Epilogue interface
CUTLASS_DEVICE void epilogue(Accumulators const& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0),
int batch_id = 0) {
real.store(C.first, threadblock_offset);
imag.store(C.second, threadblock_offset);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -0,0 +1,152 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements efficient loading of multiplicand fragments from shared memory for warp-level
complex-valued matrix multiply-accumulate operations.
*/
#pragma once
#include "cutlass/convert.h"
#include "cutlass/zip_fragment.h"
#include "cutlass/zip_tensor_ref.h"
#include "cutlass/zip_tile_iterator.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Stream from shared memory to fragments for warp-level matrix multiply-accumulate
template <
/// The load iterator.
typename Iterator_,
/// The transformer to be applied after the data has been copied from shared memory.
typename Transformer_ = Copy<typename Iterator_::Fragment>,
/// Number of increments before iterator wraps - zero indicates no wrapping
int StageCount = 1>
struct Volta884ComplexSharedLoadStream {
/// The load iterator.
typedef Iterator_ RealIterator;
/// Zips two real-valued iterators together
typedef ZipTileIterator<RealIterator, RealIterator> Iterator;
/// The transformer.
typedef Transformer_ RealTransformer;
/// Zips two transformers
typedef ZipConvert<RealTransformer, RealTransformer> Transformer;
/// Number of increments before iterator wraps - zero indicates no wrapping
static int const kStageCount = StageCount;
/// The fragment that is copied from shared memory.
typedef typename Iterator::Fragment FetchedFragment;
/// The fragment that is obtained after the transformation by the transformer.
typedef typename Transformer::OutputFragment TransformedFragment;
/// Make sure the fragments match.
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
"");
/// The output fragment.
typedef TransformedFragment Fragment;
/// Reference type
typedef ZipTensorRef<
TensorRef<half, 4>,
TensorRef<half, 4>
> TensorRef;
/// Parameters passed from host
struct Params { };
//
// Data members
//
/// Iterator for loading fragments for warp-level matrix multiply-accumulate
Iterator iterator;
/// Fetched fragment
FetchedFragment fetched[2];
/// The transformer.
Transformer transformer;
/// Transformed fragment
TransformedFragment transformed[2];
/// Counts the number of stages
int stage_index;
//
// Methods
//
/// Ctor.
CUTLASS_DEVICE Volta884ComplexSharedLoadStream() : stage_index(0) {}
/// Ctor.
CUTLASS_DEVICE Volta884ComplexSharedLoadStream(Params const &_params,
TensorRef const &ref)
: iterator(ref), stage_index(0) {}
/// Load the data from shared memory to the fetch fragment.
CUTLASS_DEVICE void copy(int step) {
iterator.load(fetched[step % 2],
make_Coord(step + stage_index * Iterator::First::VectorizedShape::kD, 0, 0, 0));
}
/// Commit the data.
CUTLASS_DEVICE void commit(int step) {
transformer.transform(fetched[step % 2], transformed[step % 2]);
}
/// Gets the transformed fragment
CUTLASS_DEVICE
TransformedFragment &fragment(int step) { return transformed[step % 2]; }
/// Gets the transformed fragment
CUTLASS_DEVICE
TransformedFragment const &fragment(int step) const { return transformed[step % 2]; }
/// Increment the stage.
CUTLASS_DEVICE void inc_stage() {
++stage_index;
if (kStageCount && stage_index == StageCount) {
stage_index = 0;
}
}
};
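// The two-entry fetched[] / transformed[] arrays above form a simple double buffer. A rough
// outline of how a warp-level mainloop might drive it (step indexes the k-group; the exact
// sequencing is owned by the mainloop):
//
//   shared_stream.copy(step + 1);   // prefetch the next fragments from shared memory
//   shared_stream.commit(step);     // transform the fragments fetched for the current step
//   // ... issue mma.sync operations on shared_stream.fragment(step) ...
//   shared_stream.inc_stage();      // advance and wrap the stage counter at kStageCount
//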
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -0,0 +1,771 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
with the computed matrix product.
*/
#pragma once
// clang-format off
#include "cutlass/tile_stream.h"
#include "cutlass/tile_allocation.h"
#include "cutlass/gemm/mma_shared_stream.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Abstraction to select accumulators from an accumulator tile for each iteration of the epilogue
template <typename WarpGemmShape, typename WarpDelta, typename Scalar>
struct Volta884SelectAccumulators;
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Selects accumulators from Volta mma.sync.F32 layout
template <typename WarpGemmShape_, typename WarpDelta_>
struct Volta884SelectAccumulators<WarpGemmShape_, WarpDelta_, float> {
/// Shape of the warp-level matrix multiply operation
typedef WarpGemmShape_ WarpGemmShape;
/// Describes tiling of warp elements
typedef WarpDelta_ WarpDelta;
/// Data type of scalar
typedef float Scalar;
//
// Derived types and constants
//
/// (Actual) number of accumulators held by each individual thread
static int const kAccumulatorsPerThread = (WarpGemmShape::kH * WarpGemmShape::kW) / kWarpSize;
/// Accumulators fragment
typedef Fragment<Scalar, kAccumulatorsPerThread> Accumulators;
/// Number of warps
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
/// Interleaved mma.sync shape
typedef Shape<4, 32, 32> MmaTileShape;
/// Hard-coded for FP32 layouts
typedef Shape<1, WarpGemmShape::kW / MmaTileShape::kW, 4> Elements;
/// Number of elements
static int const kElements = ShapeCount<Elements>::kCount;
/// Slice of accumulators
typedef Fragment<Scalar, kElements> Fragment;
//
// Methods
//
/// Selects accumulators for a given iteration of the epilogue
CUTLASS_DEVICE
Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
Fragment frag;
static int const kAccumPerOp = 8;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Elements::kH; ++j) {
// selects the 32x32 tile
Coord<2> tile_32x32 = make_Coord(idx[0] / 8, j);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Elements::kW; ++i) {
Coord<2> mma_op = make_Coord(((idx[0] >> 1) & 1), i / 2);
int element = ((i & 1) << 1) | (idx[0] & 1) | (idx[0] & 4);
int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]);
frag[i + j * Elements::kW] = accum[element + kAccumPerOp * mma_op_idx];
}
}
return frag;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Selects accumulators from Volta mma.sync.F16 layout
template <typename WarpGemmShape_, typename WarpDelta_>
struct Volta884SelectAccumulators<WarpGemmShape_, WarpDelta_, half> {
/// Shape of the warp-level matrix multiply operation
typedef WarpGemmShape_ WarpGemmShape;
/// Describes tiling of warp elements
typedef WarpDelta_ WarpDelta;
/// Data type of accumulator elements
typedef half Scalar;
//
// Derived types and constants
//
/// (Actual) number of accumulators held by each individual thread
static int const kAccumulatorsPerThread = (WarpGemmShape::kH * WarpGemmShape::kW) / kWarpSize;
/// Accumulators fragment
typedef Fragment<Scalar, kAccumulatorsPerThread> Accumulators;
/// Number of warps
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
/// Interleaved mma.sync shape
typedef Shape<4, 32, 32> MmaTileShape;
/// Hard-coded for FP16 layouts
typedef Shape<1, WarpGemmShape::kW / MmaTileShape::kW, 2> Elements;
/// Number of elements
static int const kElements = ShapeCount<Elements>::kCount;
/// Slice of accumulators
typedef Fragment<Scalar, kElements> Fragment;
//
// Methods
//
/// Selects accumulators for a given iteration of the epilogue
CUTLASS_DEVICE
Fragment operator()(Accumulators const &accum, Coord<2> const &idx) const {
Fragment frag;
static int const kAccumPerOp = 8;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Elements::kH; ++j) {
// selects the 32x32 tile
Coord<2> tile_32x32 = make_Coord(idx[0] / 16, j);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Elements::kW; ++i) {
Coord<2> mma_op = make_Coord(((idx[0] >> 2) & 1), i & 1);
int element = (idx[0] & 3) | ((idx[0] >> 1) & 4);
int mma_op_idx = mma_op[1] + mma_op[0] * 2 + 4 * (tile_32x32[1] + 2 * tile_32x32[0]);
frag[i + j * Elements::kW] = accum[element + kAccumPerOp * mma_op_idx];
}
}
return frag;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
//
//
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The warp-level GEMM tile
typename WarpGemmTile_,
/// Tiling of warp accumulator elements
typename WarpDelta_,
/// Size of vector to load or store
int AccessSize,
/// The accumulators fragment type - implies accumulator layout
typename Accumulators_>
struct Volta884EpilogueGlobalTileTraits;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Global tile traits specialized for Volta mma.sync.F32 layout
template <
/// The warp-level GEMM tile
typename WarpGemmTile_,
/// Tiling of warp accumulator elements
typename WarpDelta_,
/// Size of vector to load or store
int AccessSize>
struct Volta884EpilogueGlobalTileTraits<WarpGemmTile_, WarpDelta_, AccessSize, float> {
/// Shape of warp-scoped GEMM tile
typedef WarpGemmTile_ WarpGemmTile;
/// Structure of MMA
typedef WarpDelta_ WarpDelta;
/// Access size of input/output elements
static int const kAccessSize = AccessSize;
/// Scalar type of accumulators - used to imply accumulator layout, not the data
typedef float Accumulators;
/// Strides for immediate offset computation
typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
/// Hard-coded pitch between Volta mma.sync Quad Pair tiles
static int const kMmaQuadPairWidth = 16;
/// Hard-coded pitch between warp tiles
static int const kInterleavedTileWidth = 32;
/// Number of actual threads
static int const kThreadCount = (WarpDelta::kH * WarpDelta::kW) * kWarpSize;
/// Shape of the tile
typedef Shape<2 * WarpDelta::kH, 2, WarpGemmTile::kW * WarpDelta::kW, 1> Tile;
/// Number of iterations
typedef Shape<2 * WarpDelta::kH,
(kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
(kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
1> Iterations;
/// Delta between accesses
typedef Shape<kMmaQuadPairWidth, 2, WarpDelta::kW * kWarpSize, 1> Delta;
/// Number of warps in threadblock
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
/// Custom thread-offset function
struct ThreadOffset {
CUTLASS_DEVICE
Coord<4> operator()() {
int tid = threadIdx.x;
int residual_w = (tid / (Tile::kW));
int offset_w = (tid % (Tile::kW));
int offset_h = (residual_w % Tile::kH);
int offset_d = (residual_w / Tile::kH);
Coord<4> offset = make_Coord(offset_d * Delta::kD, offset_h * Delta::kH, offset_w, 0);
return offset;
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Global tile traits specialized for Volta mma.sync.F16 layout
template <
/// The warp-level GEMM tile
typename WarpGemmTile_,
/// Tiling of warp accumulator elements
typename WarpDelta_,
/// Size of vector to load or store
int AccessSize>
struct Volta884EpilogueGlobalTileTraits<WarpGemmTile_, WarpDelta_, AccessSize, half> {
/// Shape of warp-scoped GEMM tile
typedef WarpGemmTile_ WarpGemmTile;
/// Structure of MMA tiles
typedef WarpDelta_ WarpDelta;
/// Access size of input/output elements
static int const kAccessSize = AccessSize;
/// Scalar type of accumulators - used to imply accumulator layout, not the data
typedef half Accumulators;
/// Hard-coded pitch between Volta mma.sync Quad Pair tiles
static int const kMmaQuadPairWidth = 16;
/// Hard-coded pitch between warp tiles
static int const kInterleavedTileWidth = 32;
/// Number of participating threads
static int const kThreadCount = kWarpSize * WarpDelta::kH * WarpDelta::kW;
/// Shape of the tile
typedef Shape<1, 2 * WarpDelta::kH, WarpGemmTile::kW * WarpDelta::kW, 1> Tile;
/// Strides for immediate offset computation
typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
/// Number of iterations
typedef Shape<
1,
(kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
(kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
1> Iterations;
/// Delta between thread-level accesses
typedef typename platform::conditional<
kThreadCount >= Tile::kW,
Shape<1, kMmaQuadPairWidth * (kThreadCount / Tile::kW), 1, 1>,
Shape<1, kMmaQuadPairWidth, kThreadCount, 1>
>::type Delta;
/// Number of warps in threadblock
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
/// Custom thread-offset function
struct ThreadOffset {
CUTLASS_DEVICE
Coord<4> operator()() {
int tid = threadIdx.x;
int residual_w = (tid / (Tile::kW));
int offset_w = (tid % (Tile::kW));
int offset_h = (residual_w % Tile::kH);
int offset_d = (residual_w / Tile::kH);
Coord<4> offset = make_Coord(offset_d * Delta::kD, offset_h * kMmaQuadPairWidth, offset_w, 0);
return offset;
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Global offset functor for Volta884 epilogues
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename WarpDelta, typename AccumulatorType>
struct Volta884EpilogueGlobalOffset;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Functor computing the offset from the threadblock origin per iteration of
/// the epilogue. Specialized for Volta mma.sync.F32
template <typename WarpDelta>
struct Volta884EpilogueGlobalOffset<WarpDelta, float> {
/// mma.sync instructions are arranged as spatially overlapping 32x32 tiles
typedef Shape<4, 32, 32> MmaTileShape;
CUTLASS_DEVICE
Coord<3> operator()(Coord<2> const &iteration) const {
int h = iteration[0];
// C++ needs a better way to express bit swizzling
int h_offset = ((h & 1) | ((h & 2) << 1) | (((h & 4) >> 2) * 8) |
(((h & 8) >> 3) * WarpDelta::kH * MmaTileShape::kH));
return make_Coord(0, h_offset, iteration[1] * MmaTileShape::kW * WarpDelta::kW);
}
};
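// Worked example of the swizzle above (assuming WarpDelta::kH = 2): successive h map to
//   h:        0  1  2  3  4  5   6   7   8   9
//   h_offset: 0  1  4  5  8  9  12  13  64  65
// so bits 0..2 of h select offsets within one 32-wide mma.sync tile and bit 3 advances
// by WarpDelta::kH * MmaTileShape::kH to the next tile.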
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Functor computing the offset from the threadblock origin per iteration of
/// the epilogue. Specialized for Volta mma.sync.F16
template <typename WarpDelta>
struct Volta884EpilogueGlobalOffset<WarpDelta, half> {
/// mma.sync instructions are arranged as spatially overlapping 32x32 tiles
typedef Shape<4, 32, 32> MmaTileShape;
CUTLASS_DEVICE
Coord<3> operator()(Coord<2> const &iteration) const {
int h = iteration[0];
// C++ needs a better way to express bit swizzling
int h_offset = (h & 15) | (h & 16) * 2 * WarpDelta::kH;
Coord<3> offset = make_Coord(0, h_offset, iteration[1] * MmaTileShape::kW * WarpDelta::kW);
return offset;
}
};
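// Worked example of the F16 swizzle: for h < 16, h_offset == h; at h == 16 the offset
// jumps to 2 * 16 * WarpDelta::kH, so the arrangement is contiguous within a warp tile
// and then strides by whole warp tiles.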
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Epilogue traits for Volta884 epilogue
//
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue traits for Volta884 GEMMs
template <
/// The threadblock GEMM tile
typename OutputTile_,
/// The warp-level GEMM tile
typename WarpGemmTile_,
/// Tiling of warp accumulator elements
typename WarpDelta_,
/// The accumulators fragment type.
typename Accumulators_,
/// Selects a slice of accumulators
typename SelectAccumulators_,
/// The iterator to load source matrix from global memory.
typename GlobalLoadStreamC_,
/// The iterator to store the final GEMM computation to global memory.
typename GlobalStoreStreamD_,
/// The stream to store matrix product to shared memory
typename SharedStoreStreamD_,
/// The stream to load the matrix product from shared memory
typename SharedLoadStreamD_,
/// The functor computing an element-wise operation on the matrix product
typename Functor_,
/// Global memory mapping function
typename GlobalDataLayout_ = MatrixLayout::ColumnMajor,
/// The index.
typename Index_ = int>
struct Volta884EpilogueTraits {
/// The output tile.
typedef OutputTile_ OutputTile;
/// The warp-level GEMM tile
typedef WarpGemmTile_ WarpGemmTile;
/// Tiling of warp accumulator elements
typedef WarpDelta_ WarpDelta;
/// The accumulators fragment type.
typedef Accumulators_ Accumulators;
/// Selects a subset of accumulators for a given epilogue iteration
typedef SelectAccumulators_ SelectAccumulators;
/// The iterator to load source matrix from global memory.
typedef GlobalLoadStreamC_ GlobalLoadStreamC;
/// The iterator to store the final GEMM computation to global memory.
typedef GlobalStoreStreamD_ GlobalStoreStreamD;
/// The stream to store matrix product to shared memory
typedef SharedStoreStreamD_ SharedStoreStreamD;
/// The stream to load the matrix product from shared memory
typedef SharedLoadStreamD_ SharedLoadStreamD;
/// The functor computing an element-wise operation on the matrix product
typedef Functor_ Functor;
/// Global memory mapping function
typedef GlobalDataLayout_ GlobalDataLayout;
/// The index.
typedef Index_ Index;
/// The scalar type of the source accumulator matrix.
typedef typename GlobalLoadStreamC::Iterator::Scalar ScalarC;
/// The scalar type of the destination accumulator matrix.
typedef typename GlobalStoreStreamD::Iterator::Scalar ScalarD;
//
// Dependent types
//
static bool const kFp32Arrangement = sizeof(typename SelectAccumulators::Scalar) == 4;
/// Skew elements
static int const kSkew = 2;
/// Number of columns of accumulators stored/loaded depends on the accumulator arrangement
static int const kColumnsPerWarp = (kFp32Arrangement ? 4 : 2);
/// mma.sync instructions are arranged as spatially overlapping 32x32 tiles
typedef Shape<4, 32, 32> MmaTileShape;
/// Cover an entire warp-level tile
typedef Shape<1,
WarpGemmTile::kH / kColumnsPerWarp, // iterates over 32x32 accumulator tiles along N dimension
1, // iterates over 32x32 accumulator tiles along M dimension
1>
Iterations;
/// Skew is needed to reduce shared memory bank conflicts - this shape depends on accumulator layout
typedef Shape<1,
WarpDelta::kH * kColumnsPerWarp, // multiple columns in the gemm N dimension
WarpDelta::kW * WarpGemmTile::kW + kSkew, // rows in the gemm M dimension
1
> EpilogueTileAllocation;
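// Worked example (illustrative; assumes WarpGemmTile = Shape<4, 64, 64>,
// WarpDelta = Shape<1, 2, 2, 1>, and F32 accumulators):
//
//   kColumnsPerWarp        = 4
//   Iterations             = Shape<1, 64 / 4, 1, 1>          = Shape<1, 16, 1, 1>
//   EpilogueTileAllocation = Shape<1, 2 * 4, 2 * 64 + 2, 1>  = Shape<1, 8, 130, 1>
//
// The kSkew term pads the shared memory allocation to reduce bank conflicts.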
/// Parameters structure initialized on the host
struct Params {
/// The params for the C iterator.
typename GlobalLoadStreamC::Params load_stream_c;
/// The params for the D global iterator.
typename GlobalStoreStreamD::Params store_stream_d;
/// Epilogue functor params
typename Functor::Params functor;
/// The params for the D shared store iterator.
typename SharedStoreStreamD::Params shared_store_stream_d;
/// The params for the D shared load stream.
typename SharedLoadStreamD::Params shared_load_stream_d;
/// Batch stride for the C operand
long long int batch_stride_C;
/// Batch stride for the D operand
long long int batch_stride_D;
//
// Methods
//
/// Default constructor
CUTLASS_HOST_DEVICE
Params() {}
/// Helper constructor taking pointer, stride for source and destination matrices and functor
/// params
CUTLASS_HOST_DEVICE
Params(ScalarD *ptr_D,
int ldd,
ScalarC const *ptr_C,
int ldc,
typename Functor::Params _functor = Functor::Params())
: load_stream_c(), store_stream_d(), functor(_functor) {}
/// Setup the params.
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
batch_stride_C = desc.batch_stride_C;
batch_stride_D = desc.batch_stride_D;
// The parameters for the functor.
int error_code = functor.initialize(desc);
if (error_code) {
return error_code;
}
// Setup the params for the global memory iterator for C.
error_code = load_stream_c.iterator.initialize(
desc.C.data(), desc.C.leading_dim(), desc.C.leading_dim(), 1
);
if (error_code) {
return error_code;
}
// Setup the params for the global memory iterator for D.
return store_stream_d.iterator.initialize(
desc.D.data(), desc.D.leading_dim(), desc.D.leading_dim(), 1
);
}
};
/// Shared memory buffer used by epilogue
typedef TileAllocation<
typename SharedStoreStreamD::Iterator::Scalar,
EpilogueTileAllocation> SharedStorage;
/// Functor computing the offset from the threadblock origin per iteration of
/// the epilogue.
typedef Volta884EpilogueGlobalOffset<WarpDelta, typename SelectAccumulators::Scalar> GlobalOffset;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Volta884 Epilogue helper
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename TileTraits, typename AccumulatorType>
struct Volta884EpiloguePredicateFunctor;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Functor specialized for the predicate arrangement in the Volta884 epilogue
template <typename TileTraits>
struct Volta884EpiloguePredicateFunctor<TileTraits, float> {
/// Dimensions of the bounding volume
Coord<3> bounds;
/// Constructs a predicate functor given the bounds of a tensor
CUTLASS_HOST_DEVICE
Volta884EpiloguePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {}
/// Computes the predicate given the logical position of an access
CUTLASS_HOST_DEVICE
bool operator()(Coord<3> const &iteration, Coord<3> const &offset) const {
return
(iteration[0] * TileTraits::Delta::kD + iteration[1] * TileTraits::Delta::kH +
offset[1] < bounds[1]) &&
(iteration[2] * TileTraits::Delta::kW + offset[2] < bounds[2]);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Functor specialized for the predicate arrangement in the Volta884 epilogue
template <typename TileTraits>
struct Volta884EpiloguePredicateFunctor<TileTraits, half> {
/// Dimensions of the bounding volume
Coord<3> bounds;
/// Constructs a predicate functor given the bounds of a tensor
CUTLASS_HOST_DEVICE
Volta884EpiloguePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {}
/// Computes the predicate given the logical position of an access
CUTLASS_HOST_DEVICE
bool operator()(Coord<3> const &iteration, Coord<3> const &offset) const {
return iteration[1] * TileTraits::Delta::kH + offset[1] < bounds[1] &&
iteration[2] * TileTraits::Delta::kW + offset[2] < bounds[2];
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Volta884 Epilogue helper
//
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to define the traits for a Volta884 Epilogue
template <
typename GemmConfig_,
typename EpilogueFunctor_,
typename MultiplyAdd_ = typename GemmConfig_::MultiplyAdd,
typename Index_ = int>
struct Volta884GemmEpilogueTraitsHelper {
/// Configuration object defining GEMM properties
typedef GemmConfig_ GemmConfig;
/// Warp-level tile
typedef typename GemmConfig::AccumulatorsPerWarp WarpGemmShape;
/// Warp delta
typedef typename ShapeDiv<
typename GemmConfig::OutputTile,
WarpGemmShape>::Shape WarpDelta;
/// Thread-block scoped tile
typedef typename cutlass::ShapeMul<
WarpGemmShape,
WarpDelta
>::Shape OutputTile;
/// Multiply-add operation
typedef MultiplyAdd_ MultiplyAdd;
/// Epilogue functor
typedef EpilogueFunctor_ Functor;
/// Traits for global tile access
typedef cutlass::gemm::Volta884EpilogueGlobalTileTraits<
WarpGemmShape,
WarpDelta,
1,
typename MultiplyAdd::ScalarC
> EpilogueGlobalTileTraits;
/// Iterator to load a slice of the C matrix from global memory
typedef cutlass::TileLoadIterator<
EpilogueGlobalTileTraits,
typename GemmConfig::ScalarC,
cutlass::IteratorAdvance::kW,
cutlass::MemorySpace::kGlobal
> TileLoadIteratorC;
/// Conversion from C data type to accumulator data type
typedef Convert<
typename TileLoadIteratorC::Fragment,
Fragment<typename MultiplyAdd::ScalarC, TileLoadIteratorC::Fragment::kElements>
> ConvertSourceFragment;
/// Iterator to store a slice of the D matrix to global memory
typedef cutlass::TileStoreIterator<
EpilogueGlobalTileTraits,
typename GemmConfig::ScalarD,
cutlass::IteratorAdvance::kW,
cutlass::MemorySpace::kGlobal
> TileStoreIteratorD;
/// Conversion from accumulator data type to D data type
typedef Convert<
Fragment<typename MultiplyAdd::ScalarC, TileStoreIteratorD::Fragment::kElements>,
typename TileStoreIteratorD::Fragment
> ConvertDestinationFragment;
/// Defines traits for an epilogue of a Volta884 GEMM
typedef cutlass::gemm::Volta884EpilogueTraits<
OutputTile,
WarpGemmShape,
WarpDelta,
typename MultiplyAdd::Accumulators,
cutlass::gemm::Volta884SelectAccumulators<
WarpGemmShape,
WarpDelta,
typename MultiplyAdd::ScalarC
>,
cutlass::PredicatedTileLoadStream<
TileLoadIteratorC,
cutlass::gemm::Volta884EpiloguePredicateFunctor<
EpilogueGlobalTileTraits,
typename MultiplyAdd::ScalarC>,
ConvertSourceFragment
>,
cutlass::PredicatedTileStoreStream<
TileStoreIteratorD,
cutlass::gemm::Volta884EpiloguePredicateFunctor<
EpilogueGlobalTileTraits,
typename MultiplyAdd::ScalarC>,
ConvertDestinationFragment
>,
cutlass::TileStoreStream<
cutlass::gemm::Volta884EpilogueSharedStoreIterator<
WarpGemmShape,
WarpDelta,
typename MultiplyAdd::ScalarC,
typename MultiplyAdd::ScalarC
>
>,
cutlass::TileLoadStream<
cutlass::gemm::Volta884EpilogueSharedLoadIterator<
WarpGemmShape,
WarpDelta,
typename MultiplyAdd::ScalarC,
1,
typename MultiplyAdd::ScalarC
>
>,
Functor
> EpilogueTraits;
};
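// Illustrative composition (assumed shapes, not required by this header): for a
// GemmConfig with OutputTile = Shape<32, 128, 128> and AccumulatorsPerWarp =
// Shape<32, 64, 64>, WarpDelta resolves to Shape<1, 2, 2> (four warps); the helper
// then assembles predicated global load/store streams for C and D plus the shared
// memory store/load streams and exposes them as EpilogueTraits for MMAEpilogue.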
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
// clang-format on

View File

@@ -0,0 +1,585 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/
#pragma once
// clang-format off
#include "cutlass/gemm/clear_accumulators.h"
#include "cutlass/gemm/gemm_config.h"
#include "cutlass/gemm/gemm_global_stream.h"
#include "cutlass/gemm/gemm_stream_pair.h"
#include "cutlass/gemm/threadblock_swizzle.h"
#include "cutlass/gemm/linear_scaling.h"
#include "cutlass/kernel_launch.h"
#include "cutlass/gemm/gemm_desc.h"
#include "cutlass/gemm/volta884_multiplicand.h"
#include "cutlass/gemm/volta884_multiply_add.h"
#include "cutlass/gemm/mma_global_stream.h"
#include "cutlass/gemm/mma_shared_stream.h"
#include "cutlass/gemm/volta884_gemm_epilogue_traits.h"
#include "cutlass/gemm/mma_epilogue.h"
#include "cutlass/gemm/gemm_mainloop.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines configuration for Volta884 GEMM
template <
/// The layout for A.
MatrixLayout::Kind LayoutA,
/// The layout for B.
MatrixLayout::Kind LayoutB,
/// The tile size for the GEMM KxNxM.
typename OutputTile_,
/// Tile size for warp-level GEMM (K-by-N-by-M)
typename WarpGemmShape_,
/// The accumulator type.
typename Accumulator_,
/// The source matrix type.
typename ScalarC_,
/// The destination matrix type
typename ScalarD_,
/// Number of stages in shared memory
int StageCount,
/// If true, kernel is launched with CUDA launch bounds specified
bool kLaunchBounds = true,
/// If true, a separate loop handles the residue. If false, residue is computed in the mainloop.
bool kResidueSeparate = true,
/// Is residue performed in prologue?
bool kResidueInProlog = false>
struct Volta884GemmConfig : public GemmConfig<
/// The scalar type for A.
half,
/// The scalar type for B.
half,
/// The scalar type for C.
ScalarC_,
/// The scalar type for D.
ScalarD_,
/// The threadblock tile size
OutputTile_,
/// The functor to do the math in the main loop.
Volta884MultiplyAdd<WarpGemmShape_,
LayoutA,
half,
LayoutB,
half,
Accumulator_>,
/// The number of scalars per LDG for A.
8,
/// The number of scalars per STS for A.
8,
/// The number of scalars per LDS for A.
8,
/// The number of scalars per LDG for B.
8,
/// The number of scalars per STS for B.
8,
/// The number of scalars per LDS for B.
8,
/// The number of scalars per LDG for C and STG for D.
16 / int(sizeof(ScalarD_)),
/// The number of scalars per STS for D.
16 / int(sizeof(ScalarD_)),
/// The number of scalars per LDS for D.
16 / int(sizeof(ScalarD_)),
/// The number of stages in shared memory.
StageCount,
/// If true, separate mainloop is instantiated
kResidueSeparate,
/// If true, compute residue in prolog
kResidueInProlog,
/// Whether the kernel is launched with CUDA launch bounds
kLaunchBounds> {};
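// Access widths implied by the configuration above: A and B are always moved
// 8 half elements (16 bytes) at a time, while the C/D access width follows ScalarD_,
// e.g. 16 / sizeof(half) = 8 elements for F16 output and 16 / sizeof(float) = 4
// elements for F32 output.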
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines components of Volta884 GEMM
template <
/// The layout for A.
MatrixLayout::Kind LayoutA,
/// The layout for B.
MatrixLayout::Kind LayoutB,
/// The tile size for the GEMM KxNxM.
typename OutputTile_,
/// Tile size for warp-level GEMM (K-by-N-by-M)
typename WarpGemmShape_,
/// The accumulator type.
typename Accumulator_,
/// The input matrix type.
typename ScalarC_,
/// The output matrix type.
typename ScalarD_,
/// Number of buffers in shared memory to use
int StageCount,
/// The functor to do the math in the epilogue.
typename EpilogueFunctor_ = LinearScaling<Accumulator_>,
/// The block swizzle to reorganize the grid.
typename BlockSwizzle_ = IdentityBlockSwizzle,
/// Selectively enables launch bounds
bool LaunchBounds = false
>
struct Volta884GemmTraits {
/// This traits
typedef Volta884GemmTraits<
LayoutA,
LayoutB,
OutputTile_,
WarpGemmShape_,
Accumulator_,
ScalarC_,
ScalarD_,
StageCount,
EpilogueFunctor_,
BlockSwizzle_,
LaunchBounds> This_;
/// The struct that consumes this Traits
typedef typename cutlass::gemm::GemmMainloop<This_> KernelClass;
/// Layout of multiplicand A matrix
static MatrixLayout::Kind const kLayoutA = LayoutA;
/// Layout of multiplicand B matrix
static MatrixLayout::Kind const kLayoutB = LayoutB;
/// Dimensions of threadblock tile (concept Shape)
typedef OutputTile_ OutputTile;
/// Shape of warp-level accumulators
typedef WarpGemmShape_ WarpGemmShape;
/// Multiplicand A scalar type
typedef half ScalarA;
/// Multiplicand B scalar type
typedef half ScalarB;
/// Data type of internal accumulator
typedef Accumulator_ Accumulator;
/// Data type of input accumulator matrix operand
typedef ScalarC_ ScalarC;
/// Data type of output accumulator matrix operand
typedef ScalarD_ ScalarD;
/// Shape of individual mma.sync instruction
typedef Shape<4, 16, 16> InstructionShape;
/// Tile size for an individual warp-level multiply-add
typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;
/// Defines properties about GEMM needed by host code
typedef Volta884GemmConfig<kLayoutA,
kLayoutB,
OutputTile,
WarpGemmShape,
Accumulator,
ScalarC,
ScalarD,
StageCount,
LaunchBounds>
GemmConfig;
//
// Derived types
//
/// Index type
typedef int Index;
/// Partitioning of threadblock into warps
typedef typename ShapeDiv<OutputTile, WarpGemmShape>::Shape WarpDelta;
/// Number of warps per threadblock
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
/// Defines iterators for A matrix
typedef Volta884Multiplicand<GemmOperand::kA, kLayoutA, OutputTile, WarpTile, kWarpCount, WarpDelta>
MultiplicandA;
/// Defines iterators for B matrix
typedef Volta884Multiplicand<GemmOperand::kB, kLayoutB, OutputTile, WarpTile, kWarpCount, WarpDelta>
MultiplicandB;
//
// GemmTraits mandatory type definitions
//
/// Maps hardware threadblocks to logical partitions of the GEMM
typedef BlockSwizzle_ BlockSwizzle;
/// Clears accumulators
typedef ClearAccumulators<ScalarC> ClearAccumulators;
/// Loads multiplicands from global memory
typedef GlobalLoadStreamPair<
MMAGlobalLoadStream<GemmOperand::kA,
kLayoutA,
typename MultiplicandA::LoadIterator,
Copy<typename MultiplicandA::LoadIterator::Fragment>,
typename MultiplicandA::StoreIterator,
StageCount>,
MMAGlobalLoadStream<GemmOperand::kB,
kLayoutB,
typename MultiplicandB::LoadIterator,
Copy<typename MultiplicandB::LoadIterator::Fragment>,
typename MultiplicandB::StoreIterator,
StageCount>,
GemmConfig::kResidueInProlog >
GlobalLoadStream;
/// Memory needed to store the threadblock-scoped GEMM tile
typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage;
union MainLoopStorage {
/// Stores the threadblock tile
ThreadblockTileStorage threadblock_tile;
/// Storage for GEMM global stream
typename GlobalLoadStream::SharedStorage global_to_shared_stream;
};
/// Loads multiplicands from shared memory
typedef SharedStreamPair<
MMASharedLoadStream<typename MultiplicandA::WarpLoadIterator,
Copy<typename MultiplicandA::WarpLoadIterator::Fragment>,
StageCount>,
MMASharedLoadStream<typename MultiplicandB::WarpLoadIterator,
Copy<typename MultiplicandB::WarpLoadIterator::Fragment>,
StageCount> >
SharedStream;
// Multiply-add object specialized for Volta mma.sync
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
#if 0
/// Naive epilogue for updating the output matrix
typedef cutlass::gemm::Volta884NaiveEpilogue<ScalarC,
typename MultiplicandA::WarpDelta,
typename MultiplyAdd::Iterations>
Epilogue;
#else
/// Efficient epilogue
typedef cutlass::gemm::MMAEpilogue<
typename Volta884GemmEpilogueTraitsHelper<
GemmConfig,
EpilogueFunctor_
>::EpilogueTraits
> Epilogue;
#endif
/// Parameters structure
struct Params : public KernelLaunchConfiguration {
/// The dimensions of the GEMM.
GemmCoord problem_size;
/// The K range for every partition except the last one
int partitionK_range;
/// The params for the global load stream
typename GlobalLoadStream::Params global_to_shared_stream;
/// The params for the shared load stream
typename SharedStream::Params shared_stream;
/// The params for the epilogue.
typename Epilogue::Params epilogue;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() {}
/// Initialize the parameters.
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE Params(GemmDesc_ const& desc) {
initialize(desc);
}
/// Initialize the Params struct
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
// Problem size
problem_size = desc.problem_size;
// there is no partitionK in the default case
partitionK_range = problem_size[0];
// Compute grid dimensions
BlockSwizzle block_swizzle;
this->block = dim3(GemmConfig::kThreads);
this->grid = block_swizzle.get_grid_layout(
problem_size,
make_Coord_from_shape<OutputTile>());
// Compute offset to residue
Index gemm_k = problem_size[0];
Index offset_to_residue = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
Index offset_to_residue_last_partition = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
// Initialize parameter objects for the A and B global load streams
global_to_shared_stream.stream_a.initialize(
desc.A,
desc.batch_stride_A,
offset_to_residue,
offset_to_residue_last_partition);
global_to_shared_stream.stream_b.initialize(
desc.B,
desc.batch_stride_B,
offset_to_residue,
offset_to_residue_last_partition);
// The epilogue.
epilogue.initialize(desc);
return 0;
}
/// Helper to construct a GEMM params using a BLAS-like API
CUTLASS_HOST_DEVICE int initialize(Index m,
Index n,
Index k,
typename Epilogue::Scalar alpha,
ScalarA const* d_a,
Index lda,
ScalarB const* d_b,
Index ldb,
typename Epilogue::Scalar beta,
ScalarC const* d_c,
Index ldc,
ScalarD* d_d,
Index ldd) {
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
GemmCoord(k, n, m, 1),
alpha,
TensorRef<ScalarA const, 2>(d_a, lda),
TensorRef<ScalarB const, 2>(d_b, ldb),
beta,
TensorRef<ScalarC const, 2>(d_c, ldc),
TensorRef<ScalarD, 2>(d_d, ldd)
);
return this->initialize(desc);
}
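// Illustrative host-side usage of the BLAS-like initializer above (hypothetical tile
// shapes and pointer names; a sketch only):
//
//   typedef Volta884GemmTraits<
//       MatrixLayout::kColumnMajor, MatrixLayout::kColumnMajor,
//       Shape<32, 128, 128>, Shape<32, 64, 64>, float, float, float, 2> Traits;
//
//   Traits::Params params;
//   params.initialize(m, n, k, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd);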
/// Helper to construct a batched GEMM params
CUTLASS_HOST_DEVICE int initialize(Index m,
Index n,
Index k,
typename Epilogue::Scalar alpha,
ScalarA const* d_a,
Index lda,
long long int batch_stride_A,
ScalarB const* d_b,
Index ldb,
long long int batch_stride_B,
typename Epilogue::Scalar beta,
ScalarC const* d_c,
Index ldc,
long long int batch_stride_C,
ScalarD* d_d,
Index ldd,
long long int batch_stride_D,
Index batch_count) {
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
make_Coord(k, n, m, batch_count),
alpha,
TensorRef<ScalarA const, 2>(d_a, lda),
batch_stride_A,
TensorRef<ScalarB const, 2>(d_b, ldb),
batch_stride_B,
beta,
TensorRef<ScalarC const, 2>(d_c, ldc),
batch_stride_C,
TensorRef<ScalarD, 2>(d_d, ldd),
batch_stride_D
);
return this->initialize(desc);
}
/// Helper to construct a partitionedK GEMM params
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, Index partitionK_count_, Index partitionK_multiple_ = 1) {
// partitionK GEMM is a specialized batched strided GEMM with a different K range per batch;
// the problem_size of each batch is (lastK_size, n, m)
// the K range for every batch except the last one
partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_;
partitionK_range = partitionK_range - (partitionK_range % partitionK_multiple_);
// the k range of the last batch
// int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range;
int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1);
assert((partitionK_range % partitionK_multiple_) == 0);
assert(partitionK_range > 0);
assert((lastK_range % partitionK_multiple_) == 0);
assert(lastK_range > 0);
int k_size = lastK_range;
int lda = partitonK_desc.A.stride(0);
int ldb = partitonK_desc.B.stride(0);
int ldc = partitonK_desc.C.stride(0);
int ldd = partitonK_desc.D.stride(0);
int n = partitonK_desc.problem_size.n();
long long int batch_stride_A = (kLayoutA == cutlass::MatrixLayout::kColumnMajor) ? lda * partitionK_range : partitionK_range;
long long int batch_stride_B = (kLayoutB == cutlass::MatrixLayout::kColumnMajor) ? partitionK_range : partitionK_range * ldb;
long long int batch_stride_C = ldc * n;
long long int batch_stride_D = ldd * n;
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
// we pass lastK_range as the per-batch K; every batch except the last uses partitionK_range
GemmCoord(k_size, partitonK_desc.problem_size.n(), partitonK_desc.problem_size.m(), partitionK_count_),
partitonK_desc.alpha,
partitonK_desc.A,
batch_stride_A,
partitonK_desc.B,
batch_stride_B,
partitonK_desc.beta,
partitonK_desc.C,
batch_stride_C,
partitonK_desc.D,
batch_stride_D
);
// Set the problem size.
problem_size = desc.problem_size;
// Compute grid dimensions
BlockSwizzle block_swizzle;
this->block = dim3(GemmConfig::kThreads);
this->grid = block_swizzle.get_grid_layout(
problem_size,
make_Coord_from_shape<OutputTile>());
// Compute offset to residue.
// partitionK_range <= problem_size[0]
Index gemm_k = problem_size[0];
Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
// Initialize parameter objects for the A and B global load streams
int error_code = global_to_shared_stream.stream_a.initialize(
desc.A,
desc.batch_stride_A,
offset_to_residue,
offset_to_residue_last_partition
);
if (error_code) {
return error_code;
}
error_code = global_to_shared_stream.stream_b.initialize(
desc.B,
desc.batch_stride_B,
offset_to_residue,
offset_to_residue_last_partition
);
if (error_code) {
return error_code;
}
// The epilogue.
return epilogue.initialize(desc);
}
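// Worked example of the partitioning above (illustrative numbers): for k = 1000,
// partitionK_count_ = 4, and partitionK_multiple_ = 8, partitionK_range starts at
// 1000 / 4 = 250 and is rounded down to 248 to be a multiple of 8; the last partition
// then covers lastK_range = 1000 - 3 * 248 = 256, which also satisfies the asserts.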
/// Helper to construct a partitionedK GEMM params
CUTLASS_HOST_DEVICE int initialize(Index m,
Index n,
Index k,
typename Epilogue::Scalar alpha,
ScalarA const* d_a,
Index lda,
ScalarB const* d_b,
Index ldb,
typename Epilogue::Scalar beta,
ScalarC const* d_c,
Index ldc,
ScalarD* d_d,
Index ldd,
Index partitionK_count_,
Index partitionK_multiple_ = 1) {
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
GemmCoord(k, n, m, 1),
alpha,
TensorRef<ScalarA const, 2>(d_a, lda),
TensorRef<ScalarB const, 2>(d_b, ldb),
beta,
TensorRef<ScalarC const, 2>(d_c, ldc),
TensorRef<ScalarD, 2>(d_d, ldd)
);
return this->initialize(desc, partitionK_count_, partitionK_multiple_);
}
};
/// Shared memory storage
union SharedStorage {
/// Storage required during mainloop phase
MainLoopStorage main_loop;
/// Shared storage needed for epilogue
typename Epilogue::SharedStorage epilogue;
};
/// The memory fence for shared loads.
static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
if (StageCount < 2) {
__syncthreads();
}
}
/// The memory fence for shared stores.
static CUTLASS_DEVICE void shared_store_fence(bool in_loop) {
__syncthreads();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
// clang-format on

View File

@@ -0,0 +1,298 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
#include "cutlass/util/platform.h"
#include "cutlass/gemm/mma_global_tile.h"
#include "cutlass/gemm/volta884_shared_tile.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines iterators for loading and storing multiplicands
template <
/// Identifies multiplicand of GEMM (A or B)
GemmOperand::Kind Operand,
/// Specifies layout of data in source memory
MatrixLayout::Kind Layout,
/// Specifies threadblock tile shape
typename Tile,
/// Specifies warp tile shape
typename WarpTile,
/// Specifies the number of participating warps
int WarpCount,
/// Specifies the delta between warp tiles
typename WarpDelta_>
struct Volta884Multiplicand;
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines iterators for loading and storing multiplicands for A.column_major
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
struct Volta884Multiplicand<GemmOperand::kA,
MatrixLayout::kColumnMajor,
Tile_,
WarpTile_,
WarpCount,
WarpDelta_> {
/// Identifies multiplicand of GEMM (A or B)
static GemmOperand::Kind const kOperand = GemmOperand::kA;
/// Specifies layout of data in source memory
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
/// Thread-block tile shape
typedef Tile_ Tile;
/// Warp-level matrix multiply-add shape
typedef WarpTile_ WarpTile;
/// Total number of participating warps
static int const kWarpCount = WarpCount;
/// Delta between warp tiles
typedef WarpDelta_ WarpDelta;
//
// Thread-block load iterator
//
typedef
typename MMAThreadblockCongruousLoad<kOperand, Tile_, WarpCount, WarpDelta::kW>::Iterator
LoadIterator;
//
// Thread-block store iterator
//
typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
kLayout,
Tile_,
WarpCount,
WarpDelta::kW>
StoreIterator;
//
// Warp-level load iterator
//
typedef Volta884WarpMultiplicandLoadIterator<kOperand,
kLayout,
Tile_,
WarpTile_,
WarpCount,
WarpDelta>
WarpLoadIterator;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines iterators for loading and storing multiplicands for B.row_major
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
struct Volta884Multiplicand<GemmOperand::kB,
MatrixLayout::kRowMajor,
Tile_,
WarpTile_,
WarpCount,
WarpDelta_> {
/// Identifies multiplicand of GEMM (A or B)
static GemmOperand::Kind const kOperand = GemmOperand::kB;
/// Specifies layout of data in source memory
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
/// Thread-block tile shape
typedef Tile_ Tile;
/// Warp-level matrix multiply-add shape
typedef WarpTile_ WarpTile;
/// Total number of participating warps
static int const kWarpCount = WarpCount;
/// Delta between warp tiles
typedef WarpDelta_ WarpDelta;
//
// Thread-block load iterator
//
typedef
typename MMAThreadblockCongruousLoad<kOperand, Tile_, WarpCount, WarpDelta::kH>::Iterator
LoadIterator;
//
// Thread-block store iterator
//
typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
kLayout,
Tile_,
WarpCount,
WarpDelta::kH>
StoreIterator;
//
// Warp-level load iterator
//
typedef Volta884WarpMultiplicandLoadIterator<kOperand,
kLayout,
Tile_,
WarpTile_,
WarpCount,
WarpDelta>
WarpLoadIterator;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines iterators for loading and storing multiplicands for A.row_major
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
struct Volta884Multiplicand<GemmOperand::kA,
MatrixLayout::kRowMajor,
Tile_,
WarpTile_,
WarpCount,
WarpDelta_> {
/// Identifies multiplicand of GEMM (A or B)
static GemmOperand::Kind const kOperand = GemmOperand::kA;
/// Specifies layout of data in source memory
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
/// Thread-block tile shape
typedef Tile_ Tile;
/// Warp-level matrix multiply-add shape
typedef WarpTile_ WarpTile;
/// Total number of participating warps
static int const kWarpCount = WarpCount;
/// Delta between warp tiles
typedef WarpDelta_ WarpDelta;
//
// Thread-block load iterator
//
typedef
typename MMAThreadblockCrosswiseLoad<kOperand, Tile_, WarpCount, WarpDelta::kW>::Iterator
LoadIterator;
//
// Thread-block store iterator
//
typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
kLayout,
Tile_,
WarpCount,
WarpDelta::kW>
StoreIterator;
//
// Warp-level load iterator
//
typedef Volta884WarpMultiplicandLoadIterator<kOperand,
kLayout,
Tile_,
WarpTile_,
WarpCount,
WarpDelta>
WarpLoadIterator;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines iterators for loading and storing multiplicands for B.column_major
template <typename Tile_, typename WarpTile_, int WarpCount, typename WarpDelta_>
struct Volta884Multiplicand<GemmOperand::kB,
MatrixLayout::kColumnMajor,
Tile_,
WarpTile_,
WarpCount,
WarpDelta_> {
/// Identifies multiplicand of GEMM (A or B)
static GemmOperand::Kind const kOperand = GemmOperand::kB;
/// Specifies layout of data in source memory
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
/// Thread-block tile shape
typedef Tile_ Tile;
/// Warp-level matrix multiply-add shape
typedef WarpTile_ WarpTile;
/// Total number of participating warps
static int const kWarpCount = WarpCount;
/// Delta between warp tiles
typedef WarpDelta_ WarpDelta;
//
// Thread-block load iterator
//
typedef
typename MMAThreadblockCrosswiseLoad<kOperand, Tile_, WarpCount, WarpDelta::kH>::Iterator
LoadIterator;
//
// Thread-block store iterator
//
typedef Volta884ThreadblockMultiplicandStoreIterator<kOperand,
kLayout,
Tile_,
WarpCount,
WarpDelta::kH>
StoreIterator;
//
// Warp-level load iterator
//
typedef Volta884WarpMultiplicandLoadIterator<kOperand,
kLayout,
Tile_,
WarpTile_,
WarpCount,
WarpDelta>
WarpLoadIterator;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,704 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements warp-level multiply-accumulate operations using Volta's mma.sync instruction
*/
#pragma once
#include "cutlass/arch/mma.h"
#include "cutlass/fragment.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Shape of a warp-level GEMM (K-by-N-by-M)
typename WarpGemmShape_,
/// Layout of A multiplicand
MatrixLayout::Kind LayoutA,
/// Data type of A multiplicand
typename ScalarA_,
/// Layout of B multiplicand
MatrixLayout::Kind LayoutB,
/// Data type of B multiplicand
typename ScalarB_,
/// Data type of accumulators
typename ScalarC_>
struct Volta884MultiplyAdd {
//
// Constant and type definitions
//
/// Shape of a warp-level GEMM (K-by-N-by-M)
typedef WarpGemmShape_ WarpGemmShape;
/// Shape of a warp-level GEMM (K-by-N-by-M)
typedef WarpGemmShape_ AccumulatorsPerWarp;
/// Most of the Volta884 code assumes interleaved 32x32 tiles
typedef Shape<4, 32, 32> InterleavedTileShape;
/// Shape of an individual warp-wide Volta mma.sync instruction
typedef Shape<4, 16, 16> InstructionShape;
/// Shape of a warp-level matrix multiply operation
typedef Shape<InstructionShape::kD, WarpGemmShape::kH, WarpGemmShape::kW> WarpTile;
/// Verify WarpTile is a multiple of fundamental 32x32 interleaved tile
static_assert(!(WarpTile::kH % InterleavedTileShape::kH) &&
!(WarpTile::kW % InterleavedTileShape::kW) && WarpTile::kD == 4,
"WarpTile must be a multiple of InterleavedTileShape.");
/// Layout of A multiplicand
static MatrixLayout::Kind const kLayoutA = LayoutA;
/// Layout of B multiplicand
static MatrixLayout::Kind const kLayoutB = LayoutB;
/// The type for A.
typedef ScalarA_ ScalarA;
/// The type for B.
typedef ScalarB_ ScalarB;
/// The type for C and D.
typedef ScalarC_ ScalarC;
/// Hard-coded compute type supported on Volta
static arch::ComputeType::Kind const kComputeType = arch::ComputeType::kDefault;
/// Defines a warp-level matrix multiply-accumulate operation performed by a warp.
//
// The layout is as follows. The entire warp performs a 64x64x4 GEMM using Volta mma.sync macros
// arranged as a 2x2 tile of adjacent, 32x32x4 matrix products. These are implemented as a
// 2x2 arrangement of spatially interleaved Volta mma.sync macros.
//
// The Iterations shape maps to the following dimensions of the above warp-level GEMM:
//
// kC: number of rows of Volta mma.sync macros in 32x32x4 tile
// kW: number of columns of Volta mma.sync macros in 32x32x4 tile
// kH: number of rows of 32x32x4 macros in larger 64x64x4 tile
// kD: number of columns of 32x32x4 macros in larger 64x64x4 tile
//
// A column-major ordering would arrange C and H as the inner-most loops, with W and D as the
// outer-most.
//
typedef Shape<WarpTile::kH / InterleavedTileShape::kH,
WarpTile::kW / InterleavedTileShape::kW,
InterleavedTileShape::kH / InstructionShape::kH,
InterleavedTileShape::kW / InstructionShape::kW>
Iterations;
/// Number of multiplicand elements per instruction
static int const kMultElementsPerInst = 4;
/// Number of accumulator elements per instruction
static int const kAccumElementsPerInst = 8;
/// Fragment definition for A multiplicand
typedef Fragment<ScalarA, Iterations::kH * Iterations::kC * kMultElementsPerInst> FragmentA;
/// Fragment definition for B multiplicand
typedef Fragment<ScalarB, Iterations::kW * Iterations::kD * kMultElementsPerInst> FragmentB;
/// Fragment definition for accumulators
typedef Fragment<ScalarC, ShapeCount<Iterations>::kCount * kAccumElementsPerInst> Accumulators;
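// Worked example (illustrative): for WarpGemmShape = Shape<4, 64, 64>, the warp tile is
// 64x64x4 and Iterations = Shape<2, 2, 2, 2> - a 2x2 grid of interleaved 32x32 tiles,
// each itself a 2x2 grid of 16x16 mma.sync macros. Per thread this yields
//   FragmentA:     2 * 2 * 4 = 16 elements
//   FragmentB:     2 * 2 * 4 = 16 elements
//   Accumulators: 16 * 8     = 128 elements
// which matches the 64 * 64 accumulators distributed over the 32 threads of a warp.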
//
// Methods
//
/// Ctor.
CUTLASS_DEVICE Volta884MultiplyAdd() {}
/// Multiply : d = (-)a*b + c.
CUTLASS_DEVICE void multiply_add(FragmentA const& A,
FragmentB const& B,
Accumulators const& C,
Accumulators& D,
bool negate = false) {
// Guard conditional needed for __hneg2
#if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <= 750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) { // Outer column
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) { // Inner column
CUTLASS_PRAGMA_UNROLL
for (int h_raw = 0; h_raw < Iterations::kH; ++h_raw) { // Outer row
CUTLASS_PRAGMA_UNROLL
for (int c_raw = 0; c_raw < Iterations::kC; ++c_raw) { // Inner row
int op_col = (w + Iterations::kW * d);
// Column-major serpentine sequence to maximize reuse of B operand.
int h = h_raw;
int c = c_raw;
if (op_col & 1) {
h = Iterations::kH - h_raw - 1;
c = Iterations::kC - c_raw - 1;
}
int op_row = (c + Iterations::kC * h);
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
ScalarA operand_A[kMultElementsPerInst];
reinterpret_cast<uint64_t&>(operand_A[0]) =
reinterpret_cast<uint64_t const&>(A[op_row * kMultElementsPerInst]);
if (negate) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kMultElementsPerInst; i += 2) {
reinterpret_cast<half2&>(operand_A[i]) =
__hneg2(reinterpret_cast<half2 const&>(A[op_row * kMultElementsPerInst + i]));
}
}
// Issue a Volta mma.sync instruction
arch::mma<InstructionShape,
kLayoutA,
ScalarA,
kLayoutB,
ScalarB,
ScalarC,
kComputeType>(
operand_A, //&A[op_row * kMultElementsPerInst],
&B[op_col * kMultElementsPerInst],
&C[op_idx * kAccumElementsPerInst],
&D[op_idx * kAccumElementsPerInst]);
}
}
}
}
#endif // if (__CUDA_ARCH__ >= 700 && __CUDA_ARCH__ <=750 && CUTLASS_ENABLE_TENSOR_CORE_MMA)
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Accumulator, typename WarpDelta, typename Iterations>
struct Volta884NaiveEpilogue;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Naive epilogue specialized for f32 accumulators - may be considered the authoritative mapping of
/// accumulators to mma.sync operations.
template <typename WarpDelta_, typename Iterations_>
struct Volta884NaiveEpilogue<float, WarpDelta_, Iterations_> {
/// Accumulator data type
typedef float ScalarC;
/// Output accumulator type
typedef float ScalarD;
/// BLAS Scalar type
typedef float Scalar;
/// Delta among warp tiles
typedef WarpDelta_ WarpDelta;
/// Number of Volta mma.sync operations
typedef Iterations_ Iterations;
/// Most of the Volta884 code assumes interleaved 32x32 tiles
typedef Shape<4, 32, 32> InterleavedTileShape;
/// Number of accumulator elements per instruction
static int const kAccumElementsPerInst = 8;
/// Fragment definition for accumulators
typedef Fragment<ScalarC, ShapeCount<Iterations>::kCount * kAccumElementsPerInst> Accumulators;
/// Params object
struct Params {
/// Pointer to output matrix
ScalarC* ptr;
/// stride
int ldm;
/// Scalar alpha
float alpha;
/// Scalar beta
float beta;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() : ptr(0), ldm(0), alpha(1), beta(0) {}
CUTLASS_HOST_DEVICE
Params(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0)
: ptr(_ptr), ldm(_ldm), alpha(_alpha), beta(_beta) {}
/// Initialize method
CUTLASS_HOST_DEVICE
int initialize(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0) {
ptr = _ptr;
ldm = _ldm;
alpha = _alpha;
beta = _beta;
return 0;
}
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
ptr = reinterpret_cast<ScalarC*>(desc.D.data());
ldm = desc.D.leading_dim();
alpha = desc.alpha;
beta = desc.beta;
return 0;
}
};
/// Shared storage
struct SharedStorage {};
/// Helper used to compute initial offset for each thread
struct InitialOffset {
int row_offset;
int col_offset;
/// Constructor
CUTLASS_DEVICE
InitialOffset() {
int warp_id = (threadIdx.x >> 5);
int lane_id = (threadIdx.x & 0x1f);
int quad_id = (lane_id >> 2);
int quadpair_id = (quad_id & 0x3);
int quadpair_row = (quadpair_id & 1);
int quadpair_col = (quadpair_id >> 1);
int quad_hilo = (quad_id >> 2) & 1;
// compute initial offset
int warp_row_offset = (warp_id % WarpDelta::kW) * InterleavedTileShape::kW;
int warp_col_offset = (warp_id / WarpDelta::kW) * InterleavedTileShape::kH;
int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 1);
int thread_col_offset = (quadpair_col * 2) * 8 + (lane_id & 2);
row_offset = warp_row_offset + thread_row_offset;
col_offset = warp_col_offset + thread_col_offset;
}
};
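// Illustrative lane mapping for warp 0, quad 0: lane 0 starts at (row_offset, col_offset)
// = (0, 0), lane 1 at (1, 0), lane 2 at (0, 2), and lane 3 at (1, 2); the remaining quads
// are displaced by multiples of 8 rows (quad hi/lo) and 16 rows or columns (quad-pair
// position) within the interleaved 32x32 tile.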
//
// Data members
//
/// Parameters object
Params params;
/// Problem size
Coord<3> problem_size;
//
// Methods
//
/// Computes initial offset for each thread
CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params,
Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
: params(_params), problem_size(_problem_size) {}
/// Computes initial offset for each thread
CUTLASS_DEVICE Volta884NaiveEpilogue(ScalarC* _ptr,
int _ldm,
Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
: params(_ptr, _ldm), problem_size(_problem_size) {}
/// Computes initial offset for each thread
CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params,
SharedStorage& shared_storage,
Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
: params(_params), problem_size(_problem_size) {}
/// Sets accumulators to zero
CUTLASS_DEVICE void clear(Accumulators& C) {
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Iterations::kC; ++c) {
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
CUTLASS_PRAGMA_UNROLL
for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
C[op_idx * kAccumElementsPerInst + reg] = 0;
}
}
}
}
}
}
/// Naive load operation for debugging
CUTLASS_DEVICE void load(Accumulators& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
InitialOffset initial;
initial.row_offset += threadblock_offset[2];
initial.col_offset += threadblock_offset[1];
ScalarC const* load_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;
// loads accumulators
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Iterations::kC; ++c) {
ScalarC const* op_ptr = load_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
CUTLASS_PRAGMA_UNROLL
for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
int tr = (reg & 2) + c * 4;
int tc = (reg & 1) + (reg & 4) * 2 + w * 4;
int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;
if (row < problem_size[2] && column < problem_size[1]) {
C[op_idx * kAccumElementsPerInst + reg] = op_ptr[tr + tc * params.ldm];
}
}
}
}
}
}
}
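// Note on the register-to-element mapping used above and in store() below: with
// c == w == 0, registers 0..7 map to (tr, tc) offsets
//   (0,0) (0,1) (2,0) (2,1) (0,8) (0,9) (2,8) (2,9)
// and the c and w loop indices then shift this pattern by 4 rows and 4 columns.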
/// Naive store operation for debugging
CUTLASS_DEVICE void store(Accumulators const& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
InitialOffset initial;
initial.row_offset += threadblock_offset[2];
initial.col_offset += threadblock_offset[1];
ScalarC* store_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;
// store out accumulators
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Iterations::kC; ++c) {
ScalarC* op_ptr = store_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
CUTLASS_PRAGMA_UNROLL
for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
int tr = (reg & 2) + c * 4;
int tc = (reg & 1) + (reg & 4) * 2 + w * 4;
int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;
if (row < problem_size[2] && column < problem_size[1]) {
op_ptr[tr + tc * params.ldm] =
params.alpha * C[op_idx * kAccumElementsPerInst + reg] +
params.beta * op_ptr[tr + tc * params.ldm];
}
}
}
}
}
}
}
/// CUTLASS Epilogue interface
CUTLASS_DEVICE void epilogue(Accumulators const& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
store(C, threadblock_offset);
}
CUTLASS_DEVICE void epilogue(Accumulators& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
store(C, threadblock_offset);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Naive epilogue specialized for f16 accumulators - may be considered the authoritative mapping of
/// accumulators to mma.sync operations.
template <typename WarpDelta_, typename Iterations_>
struct Volta884NaiveEpilogue<half, WarpDelta_, Iterations_> {
/// Accumulator data type
typedef half ScalarC;
/// Output accumulator type
typedef half ScalarD;
/// BLAS Scalar type
typedef half Scalar;
/// Delta among warp tiles
typedef WarpDelta_ WarpDelta;
/// Number of Volta mma.sync operations
typedef Iterations_ Iterations;
/// Most of the Volta884 code assumes interleaved 32x32 tiles
typedef Shape<4, 32, 32> InterleavedTileShape;
/// Number of accumulator elements per instruction
static int const kAccumElementsPerInst = 8;
/// Fragment definition for accumulators
typedef Fragment<ScalarC, ShapeCount<Iterations>::kCount * kAccumElementsPerInst> Accumulators;
/// Params object
struct Params {
/// Pointer to output matrix
ScalarC* ptr;
/// stride
int ldm;
/// Scalar alpha
half alpha;
/// Scalar beta
half beta;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() : ptr(0), ldm(0), alpha(1), beta(0) {}
CUTLASS_HOST_DEVICE
Params(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0)
: ptr(_ptr), ldm(_ldm), alpha(_alpha), beta(_beta) {}
/// Initialize method
CUTLASS_HOST_DEVICE
int initialize(ScalarC* _ptr, int _ldm, float _alpha = 1, float _beta = 0) {
ptr = _ptr;
ldm = _ldm;
alpha = _alpha;
beta = _beta;
return 0;
}
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
ptr = reinterpret_cast<ScalarC*>(desc.D.data());
ldm = desc.D.leading_dim();
alpha = desc.alpha;
beta = desc.beta;
return 0;
}
};
/// Shared storage
struct SharedStorage {};
/// Helper used to compute initial offset for each thread
struct InitialOffset {
int row_offset;
int col_offset;
/// Constructor
CUTLASS_DEVICE
InitialOffset() {
int warp_id = (threadIdx.x >> 5);
int lane_id = (threadIdx.x & 0x1f);
int quad_id = (lane_id >> 2);
int quadpair_id = (quad_id & 0x3);
int quadpair_row = (quadpair_id & 1);
int quadpair_col = (quadpair_id >> 1);
int quad_hilo = (quad_id >> 2) & 1;
// compute initial offset
int warp_row_offset = (warp_id % WarpDelta::kW) * InterleavedTileShape::kW;
int warp_col_offset = (warp_id / WarpDelta::kW) * InterleavedTileShape::kH;
int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 3);
int thread_col_offset = (quadpair_col * 2) * 8;
row_offset = warp_row_offset + thread_row_offset;
col_offset = warp_col_offset + thread_col_offset;
}
};
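// Illustrative difference from the F32 arrangement above: lanes 0..3 of a quad cover four
// consecutive rows (lane_id & 3) at the same starting column, consistent with the F16
// accumulator layout assumed elsewhere in this file.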
//
// Data members
//
/// Parameters object
Params params;
/// Problem size
Coord<3> problem_size;
//
// Methods
//
/// Computes initial offset for each thread
CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params)
: params(_params), problem_size(make_Coord(1024, 1024, 1024)) {}
/// Computes initial offset for each thread
CUTLASS_DEVICE Volta884NaiveEpilogue(ScalarC* _ptr, int _ldm)
: params(_ptr, _ldm), problem_size(make_Coord(1024, 1024, 1024)) {}
/// Computes initial offset for each thread
CUTLASS_DEVICE Volta884NaiveEpilogue(Params const& _params,
SharedStorage& shared_storage,
Coord<3> const& _problem_size = make_Coord(1024, 1024, 1024))
: params(_params), problem_size(_problem_size) {}
/// Sets accumulators to zero
CUTLASS_DEVICE void clear(Accumulators& C) { C.clear(); }
/// Naive load operation for debugging
CUTLASS_DEVICE void load(Accumulators& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
InitialOffset initial;
initial.row_offset += threadblock_offset[2];
initial.col_offset += threadblock_offset[1];
ScalarC const* load_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;
// loads accumulators
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Iterations::kC; ++c) {
ScalarC const* op_ptr = load_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
CUTLASS_PRAGMA_UNROLL
for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
int tr = c * 4;
int tc = (reg & 3) + (reg & 4) * 2 + w * 4;
int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;
if (row < problem_size[2] && column < problem_size[1]) {
C[op_idx * kAccumElementsPerInst + reg] = op_ptr[tr + tc * params.ldm];
}
}
}
}
}
}
}
/// Naive store operation for debugging
CUTLASS_DEVICE void store(Accumulators const& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
InitialOffset initial;
initial.row_offset += threadblock_offset[2];
initial.col_offset += threadblock_offset[1];
ScalarC* store_ptr = params.ptr + initial.row_offset + params.ldm * initial.col_offset;
// store out accumulators
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Iterations::kC; ++c) {
ScalarC* op_ptr = store_ptr + h * WarpDelta::kW * InterleavedTileShape::kW +
d * WarpDelta::kH * InterleavedTileShape::kH * params.ldm;
int op_idx = c + Iterations::kC * (w + Iterations::kW * (h + Iterations::kH * d));
CUTLASS_PRAGMA_UNROLL
for (int reg = 0; reg < kAccumElementsPerInst; ++reg) {
int tr = c * 4;
int tc = (reg & 3) + (reg & 4) * 2 + w * 4;
int row = initial.row_offset + h * WarpDelta::kW * InterleavedTileShape::kW + tr;
int column = initial.col_offset + d * WarpDelta::kH * InterleavedTileShape::kH + tc;
if (row < problem_size[2] && column < problem_size[1]) {
op_ptr[tr + tc * params.ldm] =
params.alpha * C[op_idx * kAccumElementsPerInst + reg] +
params.beta * op_ptr[tr + tc * params.ldm];
}
}
}
}
}
}
}
/// CUTLASS Epilogue interface
CUTLASS_DEVICE void epilogue(Accumulators const& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
store(C, threadblock_offset);
}
CUTLASS_DEVICE void epilogue(Accumulators& C,
Coord<3> const& threadblock_offset = make_Coord(0, 0, 0)) {
store(C, threadblock_offset);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
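The arithmetic in InitialOffset above can be sanity-checked on the host. The following standalone sketch is not part of the commit: it replays the same per-lane computation for a single warp, using kWarpDeltaW = 1 and a 32x32 interleaved tile as stand-in values for WarpDelta::kW and InterleavedTileShape, and prints the starting (row, column) assigned to each lane.

#include <cstdio>

int main() {
  int const kWarpDeltaW = 1;     // stand-in for WarpDelta::kW
  int const kInterleavedW = 32;  // stand-in for InterleavedTileShape::kW
  int const kInterleavedH = 32;  // stand-in for InterleavedTileShape::kH
  for (int tid = 0; tid < 32; ++tid) {
    // Mirrors Volta884NaiveEpilogue::InitialOffset for threadIdx.x == tid.
    int warp_id = (tid >> 5);
    int lane_id = (tid & 0x1f);
    int quad_id = (lane_id >> 2);
    int quadpair_id = (quad_id & 0x3);
    int quadpair_row = (quadpair_id & 1);
    int quadpair_col = (quadpair_id >> 1);
    int quad_hilo = (quad_id >> 2) & 1;
    int warp_row_offset = (warp_id % kWarpDeltaW) * kInterleavedW;
    int warp_col_offset = (warp_id / kWarpDeltaW) * kInterleavedH;
    int row_offset = warp_row_offset + (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 3);
    int col_offset = warp_col_offset + (quadpair_col * 2) * 8;
    printf("lane %2d -> (row %2d, col %2d)\n", lane_id, row_offset, col_offset);
  }
  return 0;
}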

View File

@ -0,0 +1,142 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
#include "cutlass/util/platform.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Threadblock-scoped shared memory store iterators and warp-scoped load iterators
//
////////////////////////////////////////////////////////////////////////////////////////////////////
///! Iterator to store a thread-block scoped fragment to shared memory
template <
/// Identifies multiplicand of GEMM (A or B)
GemmOperand::Kind Operand,
/// Specifies layout of data in source memory
MatrixLayout::Kind Layout,
/// Specifies threadblock tile shape
typename Tile,
/// Specifies the number of participating warps
int WarpCount,
/// Specifies the delta between warp accesses along the outer dimension
int WarpDelta>
struct Volta884ThreadblockMultiplicandStoreIterator;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Iterator to load a fragment for each warp-level tile
template <
/// Identifies multiplicand of GEMM (A or B)
GemmOperand::Kind Operand,
/// Specifies layout of data in source memory
MatrixLayout::Kind Layout,
/// Specifies threadblock tile shape
typename Tile,
/// Specifies the warp tile shape
typename WarpTile,
/// Specifies the number of participating warps
int WarpCount,
/// Specifies the delta between warp accesses along the outer dimension
typename WarpDelta>
struct Volta884WarpMultiplicandLoadIterator;
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Fully-specialized implementations extracted in separate headers.
//
#include "cutlass/gemm/volta884_shared_tile_contiguous.h"
#include "cutlass/gemm/volta884_shared_tile_crosswise.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Epilogue shared memory iterators
//
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Stores an accumulator fragment to shared memory
template <
/// Shape of warp-level GEMM
typename WarpGemmTile_,
/// Tiling of warp accumulator elements
typename WarpDelta_,
/// Data type of accumulator elements
typename Scalar_,
/// Data type of mma.sync accumulator - this is used to infer layout.
typename Accumulator_>
struct Volta884EpilogueSharedStoreIterator;
/// Loads an accumulator fragment from shared memory
template <
/// Shape of warp-level GEMM
typename WarpGemmTile_,
/// Tiling of warp accumulator elements
typename WarpDelta_,
/// Data type of accumulator elements
typename Scalar_,
/// Number of scalar elements loaded
int AccessSize_,
/// Data type of mma.sync accumulator - this is used to infer layout.
typename Accumulator_>
struct Volta884EpilogueSharedLoadIterator;
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
//
// Partially-specialized implementations extracted in separate header.
//
#include "cutlass/gemm/volta884_shared_tile_epilogue.h"
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,974 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
*/
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Congruous loading
//
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Store iterator specialized for A.column_major
template <
/// Specifies threadblock tile shape
typename Tile_,
/// Specifies the number of participating warps
int WarpCount,
/// Specifies the delta between warp accesses along the outer dimension
int WarpDelta>
struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kA,
MatrixLayout::kColumnMajor,
Tile_,
WarpCount,
WarpDelta> {
//
// Constant and type definitions
//
/// Identifies multiplicand of GEMM (A or B)
static GemmOperand::Kind const kOperand = GemmOperand::kA;
/// Specifies layout of data in source memory
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
/// Shape of thread-block multiplicand
typedef Tile_ Tile;
/// Number of participating warps
static int const kWarpCount = WarpCount;
/// Delta between warp accumulator tiles along the outer dimension
static int const kWarpDelta = WarpDelta;
/// This implementation is specialized for 128b loads
static int const kAccessSize = 8;
/// Swizzled store iterator
struct ThreadOffset {
CUTLASS_DEVICE Coord<4> operator()() const {
int warp_id = (threadIdx.x >> 5);
int lane_id = (threadIdx.x & 0x1f);
int k_idx = warp_id;
// This is an 8-element vector within one 32x32 tile
int vec_idx = lane_id & 3;
int vec_col = (vec_idx / 2);
int t4t3 = (lane_id >> 3);
int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1);
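// XOR-swizzle the column with lane bits so that the warp's 128b vectors land in
// distinct shared memory banks (the "eight banks of MIO" noted below).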
int t_col = (vec_col << 2) | (col_rotate ^ t4t3);
Coord<4> offset = make_Coord(k_idx, col_rotate, t_col, 0);
return offset;
}
};
/// Projects the threadblock tile
typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;
/// Stored tile has a structure designed for efficient MIO storing and loading
typedef Shape<(OperandShape::kH >> 2), // one 3D tile per four elements in the K dimension
(OperandShape::kW >> 4), // four rows of SMEM per 64xK tile
kAccessSize, // Eight banks of MIO
kAccessSize>
VectorizedShape; // 128b stores
/// Offset between stores
typedef Shape<WarpCount, 1, 1, 1> Delta;
/// Number of iterations
typedef Shape<(VectorizedShape::kD / WarpCount), (OperandShape::kW >> 6), 1, 1> Iterations;
/// Source tile traits
typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;
/// Scalar type
typedef half Scalar;
/// Index type
typedef int Index;
/// Index type
typedef int LongIndex;
//
// Derived types
//
/// Tensor reference
typedef TensorRef<Scalar, 4> TensorRef;
/// Predicate vector
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
/// Fragment definition
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
/// Elements loaded by one instruction
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
/// The fragment iterator.
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
/// The fragment const iterator.
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
/// Strides into expected SMEM tile
typedef typename ShapeStrides<VectorizedShape, 1>::Shape Strides;
/// Memory space access
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
/// Parameters object
struct Params {
//
// Data members
//
/// Pointer to element type
Scalar *pointer;
/// Strides
Coord<4> stride;
//
// Methods
//
/// Constructs a parameters object
CUTLASS_HOST_DEVICE
Params(Scalar *_pointer = 0)
: pointer(_pointer),
stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
/// Constructs a params object from a TensorRef
CUTLASS_HOST_DEVICE
Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { }
};
//
// Data members
//
/// Parameters object
Params params;
//
// Methods
//
/// Constructs a store iterator
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator(
Params const &_params,
Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
ThreadOffset offset_func = ThreadOffset())
: params(_params) {
// Compute initial thread offset
Coord<4> offset = offset_func();
params.pointer += (_block_offset + offset).template dot<int>(params.stride);
}
/// Stores a fragment
CUTLASS_DEVICE void store(Fragment const &fragment,
Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
FragmentConstIterator frag_iterator(fragment);
// Iterate over each store
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
int idx = w + Iterations::kW * h;
int row = idx * 4;
Coord<4> sts_offset =
make_Coord(d, row, 0, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);
Store<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::store(
reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, 0)),
params.pointer,
params.stride.template dot<int>(sts_offset + offset));
}
}
}
}
/// Increments store iterator to next tile
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) {
params.pointer +=
make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
return *this;
}
/// Increments to next tile
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator++() { return increment(); }
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator+=(int count) {
return increment(count);
}
/// Decrements store iterator to previous tile
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) {
params.pointer -=
make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
return *this;
}
/// Decrements to previous tile
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator--() { return decrement(); }
/// Decrements to previous tile
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator-=(int count) {
return decrement(count);
}
/// Stores a fragment and increments in the K dimension
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &store_post_increment(
Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
store(fragment, offset);
return increment();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Iterator to load a fragment for each warp-level tile specialized for A.column_major
template <
/// Specifies threadblock tile shape
typename Tile_,
/// Specifies the warp tile shape
typename WarpTile_,
/// Specifies the number of participating warps
int WarpCount,
/// Specifies the delta between warp accesses along the outer dimension
typename WarpDelta_>
struct Volta884WarpMultiplicandLoadIterator<GemmOperand::kA,
MatrixLayout::kColumnMajor,
Tile_,
WarpTile_,
WarpCount,
WarpDelta_> {
//
// Constant and type definitions
//
/// Identifies multiplicand of GEMM (A or B)
static GemmOperand::Kind const kOperand = GemmOperand::kA;
/// Specifies layout of data in source memory
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
/// Shape of thread-block multiplicand
typedef Tile_ Tile;
/// Shape of warp-tile matrix operation
typedef WarpTile_ WarpTile;
/// Hard-coded tile shape
typedef Shape<4, 32, 32> InterleavedTileShape;
/// Number of participating warps
static int const kWarpCount = WarpCount;
/// Delta between warp accumulator tiles along the outer dimension
typedef WarpDelta_ WarpDelta;
/// Number of SMEM read pointers needed (two when WarpDelta::kW == 1)
static int const kPointerCount = (WarpDelta::kW == 1 ? 2 : 1);
/// This implementation is specialized for 128b loads
static int const kAccessSize = 8;
/// Swizzled thread offset functor for SMEM loads
struct ThreadOffset {
/// Compute thread offset coordinate for each pointer
CUTLASS_DEVICE Coord<4> operator()(int pointer_idx = 0) const {
// Determine the warp's reading location within the SMEM tile
int warp_id = ((threadIdx.x >> 5) % WarpDelta::kW);
// This is an 8-element vector within one 32x32 tile
int lane_id = (threadIdx.x & 0x1f);
int vec_row = (lane_id >> 4);
int vec_col = ((lane_id & 4) >> 2);
int tile_row = pointer_idx * 2 + vec_row;
// Column rotation function
int t_col = (vec_col * 4);
if (pointer_idx == 1 || (WarpDelta::kW > 1 && (warp_id & 1))) {
vec_row |= 2;
}
t_col = t_col | ((lane_id & 3) ^ vec_row);
Coord<4> offset = make_Coord(0, warp_id * 2 + tile_row, t_col, 0);
return offset;
}
};
/// Projects the threadblock tile
typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;
/// Stored tile has a structure designed for efficient MIO storing and loading
typedef Shape<(OperandShape::kH >> 2), // one 3D tile per four elements in the K dimension
(OperandShape::kW >> 4), // four rows of SMEM per 64xK tile
kAccessSize, // Eight banks of MIO
kAccessSize>
VectorizedShape; // 128b stores
/// Offset between accesses
typedef typename platform::conditional<WarpDelta::kW == 1,
Shape<1, 0, 0, 0>,
Shape<1, 2 * WarpDelta::kW, 0, 0> >::type Delta;
/// Number of iterations
typedef Shape<1, WarpTile::kW / InterleavedTileShape::kW, 1, 1> Iterations;
/// Source tile traits
typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;
/// Scalar type
typedef half Scalar;
/// Index type
typedef int Index;
/// Index type
typedef int LongIndex;
//
// Derived types
//
/// Tensor reference
typedef TensorRef<Scalar, 4> TensorRef;
/// Predicate vector
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
/// Fragment definition
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
/// Elements loaded by one instruction
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
/// The fragment iterator.
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
/// The fragment const iterator.
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
/// Strides into expected SMEM tile
typedef typename ShapeStrides<VectorizedShape, kAccessSize>::Shape Strides;
/// Memory space access
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
/// Parameters object
struct Params {
//
// Data members
//
/// Base pointer to SMEM allocation
Scalar const *pointer;
/// SMEM strides
Coord<4> stride;
//
// Methods
//
/// Constructs a parameters object
CUTLASS_HOST_DEVICE
Params(Scalar const *_pointer = 0)
: pointer(_pointer),
stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
/// Constructs a params object from a TensorRef
CUTLASS_HOST_DEVICE
Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { }
};
//
// Data members
//
// A.column_major requires two SMEM read pointers. Because Params supplies only a
// base pointer and strides, this iterator stores no Params member; Params is used
// to initialize the pointer array and stride below.
/// Pointer to SMEM allocation.
Scalar const *pointer[kPointerCount];
/// SMEM strides
Coord<4> stride;
//
// Methods
//
/// Constructs a load iterator
CUTLASS_DEVICE Volta884WarpMultiplicandLoadIterator(
Params const &_params,
Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
ThreadOffset offset_func = ThreadOffset())
: stride(_params.stride) {
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < kPointerCount; ++idx) {
Coord<4> offset = offset_func(idx);
pointer[idx] = _params.pointer + (_block_offset + offset).template dot<int>(stride);
}
}
/// Loads a fragment
CUTLASS_DEVICE void load(Fragment &fragment,
Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
FragmentIterator frag_iterator(fragment);
// Iterate over each load
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
// Pointers mapped to Iterations::kH dimension
Scalar const *_pointer = pointer[(kPointerCount == 2 ? h : 0)];
Coord<4> lds_offset =
make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);
Load<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::load(
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)),
_pointer,
stride.template dot<int>(lds_offset + offset));
}
}
}
}
/// Loads a fragment and increments to next K-index
CUTLASS_DEVICE void load_post_increment(Fragment &fragment,
Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
load(fragment, offset);
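// Advance each pointer by one vectorized tile along the K (D) dimension.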
for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
pointer[ptr_idx] += make_Coord(1, 0, 0, 0).template dot<int>(stride);
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Store iterator specialized for B.row_major
template <
/// Specifies threadblock tile shape
typename Tile_,
/// Specifies the number of participating warps
int WarpCount,
/// Specifies the delta between warp accesses along the outer dimension
int WarpDelta>
struct Volta884ThreadblockMultiplicandStoreIterator<GemmOperand::kB,
MatrixLayout::kRowMajor,
Tile_,
WarpCount,
WarpDelta> {
//
// Constant and type definitions
//
/// Identifies multiplicand of GEMM (A or B)
static GemmOperand::Kind const kOperand = GemmOperand::kB;
/// Specifies layout of data in source memory
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
/// Shape of thread-block multiplicand
typedef Tile_ Tile;
/// Number of participating warps
static int const kWarpCount = WarpCount;
/// Delta between warp accumulator tiles along the outer dimension
static int const kWarpDelta = WarpDelta;
/// This implementation is specialized for 128b loads
static int const kAccessSize = 8;
/// Index type
typedef int Index;
/// Index type
typedef int LongIndex;
/// Swizzled store iterator
struct ThreadOffset {
CUTLASS_DEVICE Coord<4> operator()() const {
int warp_id = (threadIdx.x >> 5);
int lane_id = (threadIdx.x & 0x1f);
int k_idx = warp_id;
// This is an 8-element vector within one 32x32 tile
int vec_idx = lane_id & 3;
int vec_col = (vec_idx / 2);
int t4t3 = (lane_id >> 3);
int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1);
int t_col = (vec_col << 2) | (col_rotate ^ t4t3);
Coord<4> offset = make_Coord(k_idx, col_rotate, t_col, 0);
return offset;
}
};
/// Projects the threadblock tile
typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;
/// Stored tile has a structure designed for efficient MIO storing and loading
typedef Shape<(OperandShape::kH >> 2), // one 3D tile per four elements in the K dimension
(OperandShape::kW >> 4), // four rows of SMEM per 64xK tile
kAccessSize, // Eight banks of MIO
kAccessSize>
VectorizedShape; // 128b stores
/// Offset between stores
typedef Shape<WarpCount, 1, 1, 1> Delta;
/// Number of iterations
typedef Shape<(VectorizedShape::kD / WarpCount), (OperandShape::kW >> 6), 1, 1> Iterations;
/// Source tile traits
typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;
/// Scalar type
typedef half Scalar;
//
// Derived types
//
/// Tensor reference
typedef TensorRef<Scalar, 4> TensorRef;
/// Predicate vector
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
/// Fragment definition
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
/// Elements loaded by one instruction
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
/// The fragment iterator.
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
/// The fragment const iterator.
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
/// Strides into expected SMEM tile
typedef typename ShapeStrides<VectorizedShape, 1>::Shape Strides;
/// Memory space access
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
/// Parameters object
struct Params {
//
// Data members
//
/// Pointer to element type
Scalar *pointer;
/// Strides
Coord<4> stride;
//
// Methods
//
/// Constructs a parameters object
CUTLASS_HOST_DEVICE
Params(Scalar *_pointer = 0)
: pointer(_pointer),
stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
/// Constructs a params object from a TensorRef
CUTLASS_HOST_DEVICE
Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { }
};
//
// Data members
//
/// Parameters object
Params params;
//
// Methods
//
/// Constructs a store iterator
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator(
Params const &_params,
Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
ThreadOffset offset_func = ThreadOffset())
: params(_params) {
// Compute initial offset for each thread
Coord<4> offset = offset_func();
params.pointer += (_block_offset + offset).template dot<int>(params.stride);
}
/// Stores a fragment
CUTLASS_DEVICE void store(Fragment const &fragment,
Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
FragmentConstIterator frag_iterator(fragment);
// Iterate over each store
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
int idx = w + Iterations::kW * h;
int row = idx * 4;
Coord<4> sts_offset =
make_Coord(d, row, 0, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);
Index _offset = params.stride.template dot<int>(sts_offset + offset);
Store<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::store(
reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, 0)),
params.pointer,
_offset);
}
}
}
}
/// Increments store iterator to next tile
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &increment(int count = 1) {
params.pointer +=
make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
return *this;
}
/// Increments to next tile
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator++() { return increment(); }
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator+=(int count) {
return increment(count);
}
/// Decrements store iterator to previous tile
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &decrement(int count = 1) {
params.pointer -=
make_Coord(VectorizedShape::kD * count, 0, 0, 0).template dot<int>(params.stride);
return *this;
}
/// Decrements to previous tile
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator--() { return decrement(); }
/// Decrements to previous tile
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &operator-=(int count) {
return decrement(count);
}
/// Stores a fragment and increments in the K dimension
CUTLASS_DEVICE Volta884ThreadblockMultiplicandStoreIterator &store_post_increment(
Fragment const &fragment, Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
store(fragment, offset);
return increment();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Iterator to load a fragment for each warp-level tile specialized for B.row_major
template <
/// Specifies threadblock tile shape
typename Tile_,
/// Specifies the warp tile shape
typename WarpTile_,
/// Specifies the number of participating warps
int WarpCount,
/// Specifies the delta between warp accesses along the outer dimension
typename WarpDelta_>
struct Volta884WarpMultiplicandLoadIterator<GemmOperand::kB,
MatrixLayout::kRowMajor,
Tile_,
WarpTile_,
WarpCount,
WarpDelta_> {
//
// Constant and type definitions
//
/// Identifies multiplicand of GEMM (A or B)
static GemmOperand::Kind const kOperand = GemmOperand::kB;
/// Specifies layout of data in source memory
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
/// Shape of thread-block multiplicand
typedef Tile_ Tile;
/// Shape of warp-tile matrix operation
typedef WarpTile_ WarpTile;
/// Hard-coded tile shape
typedef Shape<4, 32, 32> InterleavedTileShape;
/// Number of participating warps
static int const kWarpCount = WarpCount;
/// Delta between warp accumulator tiles along the outer dimension
typedef WarpDelta_ WarpDelta;
/// This implementation is specialized for 128b loads
static int const kAccessSize = 8;
/// Swizzled thread offset functor for SMEM loads
struct ThreadOffset {
/// Computes the initial offset
CUTLASS_DEVICE Coord<4> operator()(int pointer_idx) const {
// Determine the warp's reading location within the SMEM tile
int warp_id = ((threadIdx.x >> 5) / WarpDelta::kW);
// This is an 8-element vector within one 32x32 tile
int lane_id = (threadIdx.x & 0x1f);
int vec_row = (lane_id >> 4);
int vec_col = ((lane_id & 8) >> 3);
int tile_row = pointer_idx * 2 + vec_row;
// Column rotation function
int t_col = (vec_col * 4);
if (pointer_idx == 1 || (WarpDelta::kH > 1 && (warp_id & 1))) {
vec_row |= 2;
}
t_col = t_col | ((lane_id & 3) ^ vec_row);
Coord<4> offset = make_Coord(0, warp_id * 2 + tile_row, t_col, 0);
return offset;
}
};
/// Projects the threadblock tile
typedef typename GemmMultiplicandTraits<Tile_, kOperand, kLayout>::Shape OperandShape;
/// Stored tile has a structure designed for efficient MIO storing and loading
typedef Shape<(OperandShape::kH >> 2), // one 3D tile per four elements in the K dimension
(OperandShape::kW >> 4), // four rows of SMEM per 64xK tile
kAccessSize, // Eight banks of MIO
kAccessSize>
VectorizedShape; // 128b stores
/// Delta between accesses
typedef typename platform::conditional<WarpDelta::kH == 1,
Shape<1, 0, 0, 0>,
Shape<1, 2 * WarpDelta::kH, 0, 0> >::type Delta;
/// Number of iterations
typedef Shape<1, WarpTile::kH / InterleavedTileShape::kH, 1, 1> Iterations;
/// Source tile traits
typedef TileTraits<VectorizedShape, Delta, Iterations, ThreadOffset, kAccessSize> Traits;
/// Scalar type
typedef half Scalar;
/// Index type
typedef int Index;
/// Index type
typedef int LongIndex;
//
// Derived types
//
/// Tensor reference
typedef TensorRef<Scalar, 4> TensorRef;
/// Predicate vector
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
/// Fragment definition
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
/// Elements loaded by one instruction
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
/// The fragment iterator.
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
/// The fragment const iterator.
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
/// Strides into expected SMEM tile
typedef typename ShapeStrides<VectorizedShape, 1>::Shape Strides;
/// Memory space access
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
/// Number of SMEM read pointers needed
static int const kPointerCount = (WarpDelta::kH == 1 ? 2 : 1);
/// Parameters object
struct Params {
//
// Data members
//
/// Pointer to element type
Scalar const *pointer;
/// Strides
Coord<4> stride;
//
// Methods
//
/// Constructs a parameters object
CUTLASS_HOST_DEVICE
Params(Scalar const *_pointer = 0)
: pointer(_pointer),
stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) {}
/// Constructs a params object from a TensorRef
CUTLASS_HOST_DEVICE
Params(TensorRef const &ref): pointer(ref.data()), stride(make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)) { }
};
//
// Data members
//
/// Pointer to element type
Scalar const *pointer[kPointerCount];
/// Strides
Coord<4> stride;
//
// Methods
//
/// Constructs a load iterator
CUTLASS_DEVICE Volta884WarpMultiplicandLoadIterator(
Params const &_params,
Coord<4> const &_block_offset = make_Coord(0, 0, 0, 0),
ThreadOffset offset_func = ThreadOffset())
: stride(_params.stride) {
CUTLASS_PRAGMA_UNROLL
for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
Coord<4> offset = offset_func(ptr_idx);
pointer[ptr_idx] = _params.pointer + (_block_offset + offset).template dot<int>(stride);
}
}
/// Loads a fragment
CUTLASS_DEVICE void load(Fragment &fragment,
Coord<4> const &offset = make_Coord(0, 0, 0, 0)) const {
FragmentIterator frag_iterator(fragment);
// Iterate over each load
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
// Pointers mapped to Iterations::kH dimension
Scalar const *_pointer = pointer[(kPointerCount == 2 ? h : 0)];
Coord<4> lds_offset =
make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);
Load<typename Fragment::Element, VectorizedShape::kC, kMemorySpace>::load(
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)),
_pointer,
stride.template dot<int>(lds_offset + offset));
}
}
}
}
/// Loads a fragment and increments to next K-index
CUTLASS_DEVICE void load_post_increment(Fragment &fragment,
Coord<4> const &offset = make_Coord(0, 0, 0, 0)) {
load(fragment, offset);
CUTLASS_PRAGMA_UNROLL
for (int ptr_idx = 0; ptr_idx < kPointerCount; ++ptr_idx) {
pointer[ptr_idx] += make_Coord(1, 0, 0, 0).template dot<int>(stride);
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
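To visualize the swizzled layout, this standalone host-side sketch (not part of the commit) replays the ThreadOffset arithmetic of the A.column_major / B.row_major store iterators above for a single warp and prints the (row, column) slot that receives each lane's 128b vector; the k index, which equals the warp index, is omitted.

#include <cstdio>

int main() {
  for (int lane_id = 0; lane_id < 32; ++lane_id) {
    // Mirrors ThreadOffset::operator() of the store iterators above.
    int vec_idx = lane_id & 3;
    int vec_col = (vec_idx / 2);
    int t4t3 = (lane_id >> 3);
    int col_rotate = ((lane_id >> 1) & 2) | (lane_id & 1);
    int t_col = (vec_col << 2) | (col_rotate ^ t4t3);
    printf("lane %2d -> (row %d, col %2d)\n", lane_id, col_rotate, t_col);
  }
  return 0;
}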

File diff suppressed because it is too large

View File

@ -0,0 +1,629 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties for GEMM targeting Volta's mma.sync instruction
DO NOT INCLUDE THIS FILE DIRECTLY.
This file is intended to be included by <cutlass/gemm/volta884_shared_tile.h> and defines
partial specializations for templates specified therein.
*/
#pragma once
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for FP32 accumulator layouts
//
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue shared memory store iterator specialized for Volta's mma.sync.FP32 layout
template <
/// Shape of warp-level GEMM
typename WarpGemmTile_,
/// Tiling of warp accumulator elements
typename WarpDelta_,
/// Data type of accumulator elements
typename Scalar_>
struct Volta884EpilogueSharedStoreIterator<WarpGemmTile_, WarpDelta_, Scalar_, float> {
/// Warp-scoped GEMM tile size
typedef WarpGemmTile_ WarpGemmTile;
/// Tiling of warp elements across threadblock
typedef WarpDelta_ WarpDelta;
/// Scalar data type
typedef Scalar_ Scalar;
/// Accumulator data type (and layout)
typedef float Accumulator;
/// Index type
typedef int Index;
/// Index type
typedef int LongIndex;
// Host-side params
struct Params {};
/// Access size
static int const kAccessSize = 1;
/// Skew elements to ensure conflict free stores
static int const kSkew = 2;
/// Shape of one interleaved mma.sync tile
typedef Shape<4, 32, 32> MmaTileShape;
/// Four element fragment
typedef Shape<WarpGemmTile::kW / MmaTileShape::kW, 1, 4, 1> Iterations;
/// Delta separated by two elements
typedef Shape<MmaTileShape::kW * WarpDelta::kW, 1, 2, 1> Delta;
//
// Dependent types
//
/// Predicate vector
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
/// Memory space access
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
/// Fragment definition
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
/// Elements loaded by one instruction
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
/// The fragment iterator.
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
/// The fragment const iterator.
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
/// Tensor reference type
typedef TensorRef<Scalar, 4> TensorRef;
//
// Data members
//
/// Base pointer to SMEM allocation
Scalar *pointer;
/// Stride in shared memory
Coord<4> strides;
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
Volta884EpilogueSharedStoreIterator(Params const &_params, TensorRef const &ref)
: pointer(ref.data()), strides(make_Coord(1, WarpDelta::kW * WarpGemmTile::kW + kSkew, 1, 1)) {
int warp_id = (threadIdx.x / kWarpSize);
int lane_id = (threadIdx.x % kWarpSize);
Coord<4> warp_idx = make_Coord(0, warp_id / WarpDelta::kW, warp_id % WarpDelta::kW, 0);
Coord<4> warp_base = warp_idx * make_Coord(0, 4, MmaTileShape::kW, 0);
Coord<4> thread_idx = make_Coord(0,
(((lane_id >> 1) & 4) | (lane_id & 2)) >> 1,
(lane_id & 1) | ((lane_id >> 1) & 8) | ((lane_id << 2) & 16),
0);
int offset = strides.template dot<int>(warp_base + thread_idx);
pointer += offset;
}
/// Store to the epilogue tile.
CUTLASS_DEVICE
void store(Fragment const &fragment) const {
FragmentConstIterator frag_iterator(fragment);
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
Coord<4> coord =
make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);
int _offset = coord.template dot<int>(strides);
Store<typename Fragment::Element, kAccessSize, kMemorySpace>::store(
reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, 0)), pointer,
_offset);
}
}
}
}
/// Stores to the epilogue tile - this iterator does not advance, so increment is null.
CUTLASS_DEVICE
void store_post_increment(Fragment const &fragment) { store(fragment); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue shared memory load iterator specialized for Volta's mma.sync.FP32 layout
template <
/// Shape of warp-level GEMM
typename WarpGemmTile_,
/// Tiling of warp accumulator elements
typename WarpDelta_,
/// Data type of accumulator elements
typename Scalar_,
/// Number of elements loaded per access
int AccessSize_>
struct Volta884EpilogueSharedLoadIterator<WarpGemmTile_, WarpDelta_, Scalar_, AccessSize_, float> {
/// Warp-scoped GEMM tile size
typedef WarpGemmTile_ WarpGemmTile;
/// Tiling of warp elements across threadblock
typedef WarpDelta_ WarpDelta;
/// Scalar data type
typedef Scalar_ Scalar;
/// Accumulator data type (and layout)
typedef float Accumulator;
/// Index type
typedef int Index;
/// Index type
typedef int LongIndex;
/// Number of elements accessed at once
static int const kAccessSize = AccessSize_;
/// Shape of one interleaved mma.sync tile
typedef Shape<4, 32, 32> MmaTileShape;
/// Total participating warps
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
/// Total participating threads
static int const kThreadCount = kWarpCount * kWarpSize;
/// Skew elements
static int const kSkew = 2;
/// This tile is to be strip-mined with a swizzling function
typedef Shape<2 * WarpDelta::kH, 2, WarpGemmTile::kW * WarpDelta::kW, 1> Tile;
/// Number of iterations
typedef Shape<2 * WarpDelta::kH,
(kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
(kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
1>
Iterations;
/// Delta between accesses
typedef Shape<2, 1, kThreadCount, 1> Delta;
//
// Derived quantities
//
/// Predicate vector
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
/// Fragment of elements to load
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
/// Elements loaded by one instruction
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
/// The fragment iterator.
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
/// The fragment const iterator.
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
static_assert(!(kSkew % kAccessSize), "Access size must have compatible alignment with skew");
/// Memory space access
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
/// Tensor reference type
typedef TensorRef<Scalar, 4> TensorRef;
/// Host-side params
struct Params {};
//
// Data members
//
/// Pointer
Scalar const *pointer;
/// Strides
Coord<4> strides;
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
Volta884EpilogueSharedLoadIterator(Params const &_params, TensorRef const &ref)
: pointer(ref.data()),
strides(make_Coord((WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
(WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
kAccessSize,
1)) {
// strip-mine this tile
int tid = threadIdx.x;
int residual_w = (tid / (Tile::kW));
int offset_w = (tid % (Tile::kW));
int offset_h = (residual_w % Tile::kH);
int offset_d = (residual_w / Tile::kH);
Coord<4> offset = make_Coord(offset_d * Delta::kW, offset_h * Delta::kH, offset_w, 0);
pointer += strides.template dot<int>(offset);
}
/// Loads a fragment from the epilogue tile.
CUTLASS_DEVICE
void load(Fragment &fragment) const {
FragmentIterator frag_iterator(fragment);
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
Coord<4> coord =
make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);
int _offset = coord.template dot<int>(strides);
Load<typename Fragment::Element, kAccessSize, kMemorySpace>::load(
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), pointer, _offset);
}
}
}
}
/// Loads a fragment - iterator does not actually advance, so increment operation is null.
CUTLASS_DEVICE
void load_post_increment(Fragment &fragment) { load(fragment); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for FP16 accumulator layouts
//
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue shared memory store iterator specialized for Volta's mma.sync.FP16 layout
template <
/// Shape of warp-level GEMM
typename WarpGemmTile_,
/// Tiling of warp accumulator elements
typename WarpDelta_,
/// Data type of accumulator elements
typename Scalar_>
struct Volta884EpilogueSharedStoreIterator<WarpGemmTile_, WarpDelta_, Scalar_, half> {
/// Warp-scoped GEMM tile size
typedef WarpGemmTile_ WarpGemmTile;
/// Tiling of warp elements across threadblock
typedef WarpDelta_ WarpDelta;
/// Scalar data type
typedef Scalar_ Scalar;
/// Accumulator data type (and layout)
typedef half Accumulator;
/// Index type
typedef int Index;
/// Index type
typedef int LongIndex;
/// Host-side params
struct Params {};
/// Dimensions of contiguous 32x32x4 Volta's mma.sync tile
typedef Shape<4, 32, 32> MmaTileShape;
/// Accumulator fragment
typedef Shape<WarpGemmTile::kW / MmaTileShape::kW, 1, 2, 1> Iterations;
/// Delta separated by four elements
typedef Shape<MmaTileShape::kW * WarpDelta::kW, 1, 4, 1> Delta;
/// Access size
static int const kAccessSize = 1;
/// Skew elements to ensure conflict free stores
static int const kSkew = 2;
/// Tensor reference type
typedef TensorRef<Scalar, 4> TensorRef;
//
// Dependent types
//
/// Predicate vector
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
/// Memory space access
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
/// Fragment definition
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
/// Elements loaded by one instruction
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
/// The fragment iterator.
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
/// The fragment const iterator.
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
//
// Data members
//
/// Base pointer to SMEM allocation
Scalar *pointer;
/// Stride in shared memory
Coord<4> strides;
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
Volta884EpilogueSharedStoreIterator(Params const &_params, TensorRef const &ref)
: pointer(ref.data()), strides(make_Coord(1, WarpGemmTile::kW * WarpDelta::kW + kSkew, 1, 1)) {
int warp_id = (threadIdx.x / kWarpSize);
int lane_id = (threadIdx.x % kWarpSize);
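// Same quad / quad-pair / hi-lo decomposition as the mma.sync FP16 accumulator
// layout (compare Volta884NaiveEpilogue::InitialOffset); it determines the
// per-thread row and column offsets computed below.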
int quad_id = (lane_id >> 2);
int quadpair_id = (quad_id & 0x3);
int quadpair_row = (quadpair_id & 1);
int quadpair_col = (quadpair_id >> 1);
int quad_hilo = (quad_id >> 2) & 1;
int thread_row_offset = (quadpair_row * 2 + quad_hilo) * 8 + (lane_id & 3);
int thread_col_offset = quadpair_col;
Coord<4> thread_idx = make_Coord(0, thread_col_offset, thread_row_offset, 0);
Coord<4> warp_base = make_Coord(0, warp_id / WarpDelta::kW, warp_id % WarpDelta::kW, 0) *
make_Coord(0, 2, kWarpSize, 0);
Coord<4> offset = warp_base + thread_idx;
pointer += strides.template dot<int>(offset);
}
/// Store to the epilogue tile.
CUTLASS_DEVICE
void store(Fragment const &fragment) const {
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
Coord<4> coord =
make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kC);
int _offset = coord.template dot<int>(strides);
Store<typename Fragment::Element, kAccessSize, kMemorySpace>::store(
reinterpret_cast<AccessType const &>(fragment[w + Iterations::kW * d]),
pointer,
_offset);
}
}
}
}
/// Stores to the epilogue tile - this iterator does not advance, so increment is null.
CUTLASS_DEVICE
void store_post_increment(Fragment const &fragment) { store(fragment); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Epilogue shared memory load iterator specialized for Volta's mma.sync.FP16 layout
template <
/// Shape of warp-level GEMM
typename WarpGemmTile_,
/// Tiling of warp accumulator elements
typename WarpDelta_,
/// Data type of accumulator elements
typename Scalar_,
/// Number of elements loaded per access
int AccessSize_>
struct Volta884EpilogueSharedLoadIterator<WarpGemmTile_, WarpDelta_, Scalar_, AccessSize_, half> {
/// Warp-scoped GEMM tile size
typedef WarpGemmTile_ WarpGemmTile;
/// Tiling of warp elements across threadblock
typedef WarpDelta_ WarpDelta;
/// Scalar data type
typedef Scalar_ Scalar;
/// Accumulator data type (and layout)
typedef half Accumulator;
/// Number of elements accessed at once
static int const kAccessSize = AccessSize_;
/// Shape of one interleaved mma.sync tile
typedef Shape<4, 32, 32> MmaTileShape;
/// This tile is to be strip-mined with a swizzling function
typedef Shape<1, 2 * WarpDelta::kH, WarpGemmTile::kW * WarpDelta::kW / kAccessSize, kAccessSize>
Tile;
/// Index type
typedef int Index;
/// Index type
typedef int LongIndex;
/// Total participating warps
static int const kWarpCount = ShapeCount<WarpDelta>::kCount;
/// Number of participating threads
static int const kThreadCount = kWarpSize * kWarpCount;
/// Number of iterations
typedef Shape<1,
(kThreadCount >= Tile::kW ? Tile::kH / (kThreadCount / Tile::kW) : Tile::kH),
(kThreadCount >= Tile::kW ? 1 : Tile::kW / kThreadCount),
1>
Iterations;
/// Delta between thread-level accesses
typedef typename platform::conditional<kThreadCount >= Tile::kW,
Shape<1, (kThreadCount / Tile::kW), 1, 1>,
Shape<1, 1, kThreadCount, 1> >::type Delta;
//
// Derived quantities
//
/// Predicate vector
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
/// Fragment of elements to load
typedef Fragment<Scalar, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
/// Elements loaded by one instruction
typedef typename Vectorize<Scalar, kAccessSize>::Type AccessType;
/// The fragment iterator.
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
/// The fragment const iterator.
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
/// Skew elements
static int const kSkew = 2;
static_assert(!(kSkew % kAccessSize), "Access size must have compatible alignment with skew");
/// Memory space access
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGeneric;
/// Tensor reference type
typedef TensorRef<Scalar, 4> TensorRef;
/// Host-side params
struct Params {};
//
// Data members
//
/// Pointer
Scalar const *pointer;
/// Strides
Coord<4> strides;
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
Volta884EpilogueSharedLoadIterator(Params const &_params, TensorRef const &ref)
: pointer(ref.data()),
strides(make_Coord(2 * (WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
(WarpDelta::kW * WarpGemmTile::kW + kSkew) * kAccessSize,
kAccessSize,
1)) {
// strip-mine this tile
Coord<4> offset = make_Coord(0, threadIdx.x / Tile::kW, threadIdx.x % Tile::kW, 0);
pointer += strides.template dot<int>(offset);
}
/// Loads a fragment from the epilogue tile.
CUTLASS_DEVICE
void load(Fragment &fragment) const {
FragmentIterator frag_iterator(fragment);
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
Coord<4> coord =
make_Coord(d, h, w, 0) * make_Coord(Delta::kD, Delta::kH, Delta::kW, Delta::kW);
int _offset = coord.template dot<int>(strides);
Load<typename Fragment::Element, kAccessSize, kMemorySpace>::load(
reinterpret_cast<AccessType &>(fragment[w + Iterations::kW * h]), pointer, _offset);
}
}
}
}
/// Loads a fragment - iterator does not actually advance, so increment operation is null.
CUTLASS_DEVICE
void load_post_increment(Fragment &fragment) { load(fragment); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
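Both epilogue shared memory iterators above pad each row with kSkew = 2 extra elements so that successive rows start at different bank offsets. The short host-side sketch below is not part of the commit and uses WarpGemmTile::kW = 64 and WarpDelta::kW = 1 purely as example values; it prints the resulting row pitch and the bank at which each row begins.

#include <cstdio>

int main() {
  int const kWarpGemmTileW = 64;  // example warp tile width
  int const kWarpDeltaW = 1;      // example warp tiling along W
  int const kSkew = 2;            // matches the iterators above
  int row_pitch = kWarpDeltaW * kWarpGemmTileW + kSkew;  // elements per SMEM row
  for (int row = 0; row < 4; ++row) {
    int start = row * row_pitch;
    printf("row %d starts at element %3d (bank %2d, assuming 4B elements and 32 banks)\n",
           row, start, start % 32);
  }
  return 0;
}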

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -106,7 +106,7 @@ struct WmmaGemmEpilogueTraitsHelper {
// The number of scalars per LDS.
GemmConfig_::kScalarsPerLdsD,
// this parameter helps with swizzling when accum is fp32 and output is fp16
sizeof(Accumulator_) / sizeof(typename GemmConfig_::ScalarD)
int(sizeof(Accumulator_)) / int(sizeof(typename GemmConfig_::ScalarD))
>
SharedLoadTileTraits;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -103,7 +103,7 @@ struct WmmaGemmGlobalIteratorCd : public GemmGlobalIteratorCd<TileTraits_, Index
Index epilogue_stride_w,
Index epilogue_delta_w) {
// The pointer.
this->pointer = pointer;
BaseParams::pointer = pointer;
// Stride between GEMMs
this->stride_d = batch_stride;
// Setup the base stride. One "group of threads" per column.

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -85,7 +85,10 @@ struct WmmaGemmMultiplyAdd {
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Iterations::kH; ++j) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Iterations::kW; ++i) {
// The input elements.
ElementA const& elt_a = a[i];
@ -164,7 +167,10 @@ struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Iterations::kH; ++j) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Iterations::kW; ++i) {
// The input elements.
ElementA const& elt_a = a[i];
@ -249,7 +255,10 @@ struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Iterations::kH; ++j) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Iterations::kW; ++i) {
// The input elements.
ElementA const& elt_a = a[i];
@ -329,7 +338,10 @@ struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Iterations::kH; ++j) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Iterations::kW; ++i) {
// The input elements.
ElementA const& elt_a = a[i];

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -38,9 +38,13 @@ namespace cutlass {
template <typename InputIterator, typename Fragment>
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) {
typename InputIterator::FragmentIterator frag_iterator(fragment);
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
if (iterator.valid(d, h, w, c)) {
iterator.load_element(reinterpret_cast<typename InputIterator::AccessType &>(
@ -69,9 +73,13 @@ CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragme
template <typename OutputIterator, typename Fragment>
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) {
typename OutputIterator::FragmentIterator frag_iterator(fragment);
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < OutputIterator::Iterations::kC; ++c) {
if (iterator.valid(d, h, w, c)) {
iterator.store_element(reinterpret_cast<typename OutputIterator::AccessType &>(

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -0,0 +1,90 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
//#include <stdexcept>
#include "cutlass/cutlass.h"
//#include "tools/util/reference/device/kernel/tensor_foreach.h"
namespace cutlass {
namespace layout {
namespace thread {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines several helpers
namespace detail {
/// Helper to perform for-each operation
template <typename Func, int Rank, int RankRemaining>
struct TensorForEachHelper {
/// Index of the active rank
static int const kActiveRank = Rank - RankRemaining - 1;
/// Constructor for general rank
CUTLASS_DEVICE TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord) {
for (int i = 0; i < size.at(kActiveRank); ++i) {
coord[kActiveRank] = i;
TensorForEachHelper<Func, Rank, RankRemaining - 1>(func, size, coord);
}
}
};
/// Helper to perform for-each operation
template <typename Func, int Rank>
struct TensorForEachHelper<Func, Rank, 0> {
/// Index of the active rank
static int const kActiveRank = Rank - 1;
/// Constructor for fastest changing rank
CUTLASS_DEVICE TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord) {
for (int i = 0; i < size.at(kActiveRank); ++i) {
coord[kActiveRank] = i;
func(coord);
}
}
};
} // namespace detail
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Iterates over the index space of a tensor
template <typename Func, int Rank, typename Params>
struct TensorForEach {
/// Constructor performs the operation.
CUTLASS_DEVICE TensorForEach(Coord<Rank> size, Params params = Params()) {
Func func(params);
Coord<Rank> coord;
detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, size, coord);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace layout
} // namespace cutlass
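
As a rough usage sketch of the TensorForEach helper above: a functor constructed from a Params object is invoked once per coordinate, slowest rank outermost. The FillFunc and fill names below are assumptions for illustration, and the view is assumed to expose kRank, TensorCoord, Storage, size() and at() as TensorView does.

// Hypothetical example; not part of this change.
template <typename View>
struct FillFunc {
  typedef typename View::TensorCoord TensorCoord;
  View view;
  typename View::Storage value;
  CUTLASS_DEVICE FillFunc(View view_, typename View::Storage value_) : view(view_), value(value_) {}
  /// Invoked once for every coordinate in the index space
  CUTLASS_DEVICE void operator()(TensorCoord const &coord) { view.at(coord) = value; }
};

template <typename View>
CUTLASS_DEVICE void fill(View view, typename View::Storage value) {
  // Func and Params are the same type here, so TensorForEach copy-constructs the functor.
  cutlass::layout::thread::TensorForEach<FillFunc<View>, View::kRank, FillFunc<View> >(
      view.size(), FillFunc<View>(view, value));
}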

View File

@ -0,0 +1,300 @@
/***************************************************************************************************
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Basic copy routines for tensor views
*/
#pragma once
#include "cutlass/fragment.h"
#include "cutlass/layout/thread/tensor_foreach.h"
#include "cutlass/tensor_view.h"
namespace cutlass {
namespace layout {
namespace thread {
/// Define a functor that performs a copy operation on a tensor.
template <typename View_dst, typename View_src>
struct CopyFunc {
/// Coordinate of index space
typedef typename View_dst::TensorCoord TensorCoord;
View_dst dst;
View_src src;
/// Constructor
CUTLASS_DEVICE
CopyFunc(View_dst dst, View_src src) : dst(dst), src(src) {}
/// copy function
CUTLASS_DEVICE
void operator()(TensorCoord const& coord) {
dst.at(coord) = src.at(coord); // uses tensor view's map()
}
};
template <typename T_dst, typename T_src, int rank, typename MapFunc_dst, typename MapFunc_src>
struct Copy {
CUTLASS_DEVICE void copy(cutlass::TensorView<T_dst, rank, MapFunc_dst> dst,
cutlass::TensorView<T_src, rank, MapFunc_src> src) {
// Define a functor called by TensorForEach<>
typedef CopyFunc<cutlass::TensorView<T_dst, rank, MapFunc_dst>,
cutlass::TensorView<T_src, rank, MapFunc_src> >
CopyFunc;
// Instantiate on device with TensorViews
CopyFunc copy_func(dst, src);
// Invoke device-side for-each computation on the tensor
cutlass::layout::thread::TensorForEach<CopyFunc,
rank, // View::kRank
CopyFunc>(src.size(), copy_func);
}
};
template <int rank>
struct Copy<half, half, rank, cutlass::MatrixLayout::RowMajor, cutlass::MatrixLayout::RowMajor> {
CUTLASS_DEVICE void copy(cutlass::TensorView<half, rank, cutlass::MatrixLayout::RowMajor> dst,
cutlass::TensorView<half, rank, cutlass::MatrixLayout::RowMajor> src) {
bool isPacked = dst.isPacked() && src.isPacked();
if (isPacked) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < src.capacity(); ++i) {
dst.at(i) = src.at(i);
}
} else {
typedef CopyFunc<cutlass::TensorView<half, rank, cutlass::MatrixLayout::RowMajor>,
cutlass::TensorView<half, rank, cutlass::MatrixLayout::RowMajor> >
CopyFunc;
// Instantiate on device with TensorViews
CopyFunc copy_func(dst, src);
// Invoke device-side for-each computation on the tensor
cutlass::layout::thread::TensorForEach<CopyFunc,
rank, // View::kRank
CopyFunc>(src.size(), copy_func);
}
}
};
/// hgemm swizzle
/// Transform a fragment.
template <>
struct Copy<half, half, 2, cutlass::MatrixLayout::RowMajor, cutlass::MatrixLayout::ColumnMajor> {
CUTLASS_DEVICE void copy(
cutlass::TensorView<half, 2, cutlass::MatrixLayout::RowMajor> dst,
cutlass::TensorView<half, 2, cutlass::MatrixLayout::ColumnMajor> src) {
// Expose src/dst as int arrays.
int const* src_int = reinterpret_cast<int const*>(src.const_ref().data());
int* dst_int = reinterpret_cast<int*>(dst.ref().data());
int kD = src.size(0);
int kDhw = src.size(0) * src.size(1);
// Transpose the data.
// CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < kD; ++d) {
// The indices to read two consecutive "rows".
int const i0 = 2 * d + 0;
int const i1 = 2 * d + 1;
int a0 = src_int[i0];
int a1 = src_int[i1];
int b0, b1;
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(a0), "r"(a1));
asm volatile("prmt.b32 %0, %1, %2, 0x7632;" : "=r"(b1) : "r"(a0), "r"(a1));
// The indices to store with "strides".
int const j0 = 0 * (kDhw / 2) + d;
int const j1 = 1 * (kDhw / 2) + d;
dst_int[j0] = b0;
dst_int[j1] = b1;
}
}
};
/// igemm swizzle
/// Transform a fragment.
template <>
struct Copy<int8_t,
int8_t,
2,
cutlass::MatrixLayout::RowMajor,
cutlass::MatrixLayout::ColumnMajor> {
CUTLASS_DEVICE void copy(
cutlass::TensorView<int8_t, 2, cutlass::MatrixLayout::RowMajor> dst,
cutlass::TensorView<int8_t, 2, cutlass::MatrixLayout::ColumnMajor> src) {
// Expose src/dst as int arrays.
int const* src_int = reinterpret_cast<int const*>(src.const_ref().data());
int* dst_int = reinterpret_cast<int*>(dst.ref().data());
int kD = src.size(0);
int kH = src.size(1);
int kWc = src.stride(0);
int kHwc = kH * kWc;
// Transpose the data.
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < kH / 4; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < kWc / 4; ++w) {
int const i0 = d * (kHwc / 4) + (4 * h + 0) * (kWc / 4) + w;
int const i1 = d * (kHwc / 4) + (4 * h + 1) * (kWc / 4) + w;
int const i2 = d * (kHwc / 4) + (4 * h + 2) * (kWc / 4) + w;
int const i3 = d * (kHwc / 4) + (4 * h + 3) * (kWc / 4) + w;
int a0 = src_int[i0];
int a1 = src_int[i1];
int a2 = src_int[i2];
int a3 = src_int[i3];
int b0, b1, b2, b3, c0;
asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1));
asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3));
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(b0), "r"(c0));
asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(b1) : "r"(a0), "r"(a1));
asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(c0) : "r"(a2), "r"(a3));
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b1) : "r"(b1), "r"(c0));
asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(b2) : "r"(a0), "r"(a1));
asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(c0) : "r"(a2), "r"(a3));
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b2) : "r"(b2), "r"(c0));
asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(b3) : "r"(a0), "r"(a1));
asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3));
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0));
dst_int[i0] = b0;
dst_int[i1] = b1;
dst_int[i2] = b2;
dst_int[i3] = b3;
}
}
}
}
};
template <typename Shape,
int Rank,
typename DstType,
typename DstLayout,
typename SrcType,
typename SrcLayout>
struct Transform {
typedef Fragment<DstType, ShapeCount<Shape>::kCount> DstFragment;
typedef Fragment<SrcType, ShapeCount<Shape>::kCount> SrcFragment;
/// The input fragment.
typedef SrcFragment InputFragment;
/// The output fragment.
typedef DstFragment OutputFragment;
CUTLASS_DEVICE void transform(SrcFragment& src, DstFragment& dst) {
cutlass::TensorView<DstType, Rank, DstLayout> dstView(
&dst[0], // pointer to base of matrix in device memory
cutlass::make_Coord(Shape::kD, 1), // stride vector
cutlass::make_Coord(Shape::kD,
Shape::kH * Shape::kW) // bounds of matrix
);
cutlass::TensorView<SrcType, Rank, SrcLayout> srcView(
&src[0], // pointer to base of matrix in device memory
cutlass::make_Coord(Shape::kD, 1), // stride vector
cutlass::make_Coord(Shape::kD,
Shape::kH * Shape::kW) // bounds of matrix
);
cutlass::layout::thread::Copy<DstType, SrcType, Rank, DstLayout, SrcLayout> Transformer;
Transformer.copy(dstView, srcView);
}
};
template <typename Shape, int Rank, typename DstLayout, typename SrcLayout>
struct Transform<Shape, Rank, half, DstLayout, half, SrcLayout> {
typedef Fragment<half, ShapeCount<Shape>::kCount> DstFragment;
typedef Fragment<half, ShapeCount<Shape>::kCount> SrcFragment;
/// The input fragment.
typedef SrcFragment InputFragment;
/// The output fragment.
typedef DstFragment OutputFragment;
CUTLASS_DEVICE void transform(SrcFragment& src, DstFragment& dst) {
cutlass::TensorView<half, Rank, DstLayout> dstView(
&dst[0], // pointer to base of matrix in device memory
cutlass::make_Coord(Shape::kD, 1), // stride vector
cutlass::make_Coord(Shape::kD,
Shape::kH * Shape::kW) // bounds of matrix
);
cutlass::TensorView<half, Rank, SrcLayout> srcView(
&src[0], // pointer to base of matrix in device memory
cutlass::make_Coord(Shape::kD, 1), // stride vector
cutlass::make_Coord(Shape::kD,
Shape::kH * Shape::kW) // bounds of matrix
);
cutlass::layout::thread::Copy<half, half, Rank, DstLayout, SrcLayout> Transformer;
Transformer.copy(dstView, srcView);
}
};
template <typename Shape, int Rank, typename DstLayout, typename SrcLayout>
struct Transform<Shape, Rank, int8_t, DstLayout, int8_t, SrcLayout> {
typedef Fragment<int8_t, ShapeCount<Shape>::kCount> DstFragment;
typedef Fragment<int8_t, ShapeCount<Shape>::kCount> SrcFragment;
/// The input fragment.
typedef SrcFragment InputFragment;
/// The output fragment.
typedef DstFragment OutputFragment;
CUTLASS_DEVICE void transform(SrcFragment& src, DstFragment& dst) {
cutlass::TensorView<int8_t, Rank, DstLayout> dstView(
&dst[0], // pointer to base of matrix in device memory
cutlass::make_Coord(Shape::kW * Shape::kC, 1), // stride vector
cutlass::make_Coord(Shape::kD,
Shape::kH) // bounds of matrix
);
cutlass::TensorView<int8_t, Rank, SrcLayout> srcView(
&src[0], // pointer to base of matrix in device memory
cutlass::make_Coord(Shape::kW * Shape::kC, 1), // stride vector
cutlass::make_Coord(Shape::kD,
Shape::kH) // bounds of matrix
);
cutlass::layout::thread::Copy<int8_t, int8_t, Rank, DstLayout, SrcLayout> Transformer;
Transformer.copy(dstView, srcView);
}
};
} // namespace thread
} // namespace layout
} // namespace cutlass
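
A minimal sketch, not taken from this change, of driving the generic Copy<> directly with two packed row-major views; the 8x16 shape and the copy_8x16 name are assumptions, and the constructor argument pattern (pointer, stride vector, bounds) mirrors the Transform<> code above.

// Assumed illustration; exercises only the generic (non-swizzle) path.
CUTLASS_DEVICE void copy_8x16(float *dst_ptr, float *src_ptr) {
  typedef cutlass::MatrixLayout::RowMajor Layout;
  typedef cutlass::TensorView<float, 2, Layout> View;
  View dst(dst_ptr, cutlass::make_Coord(16, 1), cutlass::make_Coord(8, 16));
  View src(src_ptr, cutlass::make_Coord(16, 1), cutlass::make_Coord(8, 16));
  cutlass::layout::thread::Copy<float, float, 2, Layout, Layout> copier;
  copier.copy(dst, src);  // dispatches CopyFunc through TensorForEach, element by element
}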

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -66,6 +66,11 @@ struct Load {
dst = *reinterpret_cast<AccessType const*>(pointer + offset);
}
/// The clear function.
static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
dst = 0;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -80,6 +85,11 @@ struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
reinterpret_cast<uint16_t&>(dst) = reinterpret_cast<uint16_t const*>(&pointer[offset])[0];
}
/// The clear function.
static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
dst = uint16_t(0);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -94,6 +104,10 @@ struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_
dst.registers[0] = reinterpret_cast<uint32_t const*>(&pointer[offset])[0];
}
/// The clear function.
static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
dst.registers[0] = uint32_t(0);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -109,6 +123,13 @@ struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_
dst.registers[0] = tmp.x;
dst.registers[1] = tmp.y;
}
/// The clear function.
static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
uint2 const zero = make_uint2(0, 0);
dst.registers[0] = zero.x;
dst.registers[1] = zero.y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -124,6 +145,13 @@ struct Load<double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 1
dst[0] = tmp.x;
dst[1] = tmp.y;
}
/// The clear function.
static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
double2 zero = make_double2(0, 0);
dst[0] = zero.x;
dst[1] = zero.y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -145,6 +173,15 @@ struct Load<half, 8, Memory_, FragmentElementType::kScalar, half, kStride, 16> {
dst.registers[2] = tmp.x;
dst.registers[3] = tmp.y;
}
/// The clear function.
static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
int2 zero = make_int2(0,0);
dst.registers[0] = zero.x;
dst.registers[1] = zero.y;
dst.registers[2] = zero.x;
dst.registers[3] = zero.y;
}
};
#endif
@ -164,6 +201,15 @@ struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_
dst.registers[2] = tmp.z;
dst.registers[3] = tmp.w;
}
/// The clear function.
static CUTLASS_HOST_DEVICE void clear(AccessType& dst) {
uint4 zero = make_uint4(0, 0, 0, 0);
dst.registers[0] = zero.x;
dst.registers[1] = zero.y;
dst.registers[2] = zero.z;
dst.registers[3] = zero.w;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
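
A hedged sketch of how the new clear() entry points pair with load(); the Loader typedef assumes the primary template's defaulted trailing parameters and a MemorySpace::kGlobal memory tag, and guarded_load is a hypothetical helper.

// Assumed usage: zero the access first, then overwrite it only where the guard holds.
typedef cutlass::Load<float, 4, cutlass::MemorySpace::kGlobal> Loader;

CUTLASS_DEVICE void guarded_load(Loader::AccessType &access,
                                 float const *pointer,
                                 int offset,
                                 bool guard) {
  Loader::clear(access);  // added in this change
  if (guard) {
    Loader::load(access, pointer, offset);
  }
}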

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -139,7 +139,6 @@ struct MatrixCoord : public Coord<2, int> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines data layouts of various matrix formats usable by TensorRef and other classes.
//
// The following define classes satisfying the TensorRefMapFunc concept. These must support the
@ -367,6 +366,12 @@ struct MatrixTransform {
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Tensor layout
namespace TensorLayout {
enum Kind { kNHWC, kNCHW };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -186,9 +186,9 @@ struct PredicateVector {
CUTLASS_HOST_DEVICE
ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {}
///
/// Copy ctor
CUTLASS_HOST_DEVICE
ConstIterator(PredicateVector const &_vec, int _start = 0) : vec_(_vec), bit_(_start) {}
ConstIterator(PredicateVector const &vec, int _start = 0) : vec_(vec), bit_(_start) {}
/// Pre-increment
CUTLASS_HOST_DEVICE
@ -197,6 +197,13 @@ struct PredicateVector {
return *this;
}
/// Increment
CUTLASS_HOST_DEVICE
ConstIterator &operator+=(int offset) {
bit_ += offset;
return *this;
}
/// Pre-decrement
CUTLASS_HOST_DEVICE
ConstIterator &operator--() {
@ -204,6 +211,13 @@ struct PredicateVector {
return *this;
}
/// Decrement
CUTLASS_HOST_DEVICE
ConstIterator &operator-=(int offset) {
bit_ -= offset;
return *this;
}
/// Post-increment
CUTLASS_HOST_DEVICE
ConstIterator operator++(int) {
@ -220,6 +234,22 @@ struct PredicateVector {
return ret;
}
/// Iterator advances by some amount
CUTLASS_HOST_DEVICE
ConstIterator operator+(int offset) {
ConstIterator ret(*this);
ret.bit_ += offset;
return ret;
}
/// Iterator recedes by some amount
CUTLASS_HOST_DEVICE
ConstIterator operator-(int offset) {
ConstIterator ret(*this);
ret.bit_ -= offset;
return ret;
}
/// Returns true if iterators point to the same bit
CUTLASS_HOST_DEVICE
bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; }
@ -230,7 +260,15 @@ struct PredicateVector {
/// Dereferences iterator
CUTLASS_HOST_DEVICE
bool operator*() const { return vec_[bit_]; }
bool operator*() const { return vec_.at(bit_); }
/// Gets the bit at the pointed to location
CUTLASS_HOST_DEVICE
bool get() const { return vec_.at(bit_); }
/// Gets the bit at the pointed to location
CUTLASS_HOST_DEVICE
bool at() const { return vec_.at(bit_); }
};
/**
@ -252,7 +290,7 @@ struct PredicateVector {
/// Constructs an iterator from a PredicateVector
CUTLASS_HOST_DEVICE
Iterator(PredicateVector &_vec, int _start = 0) : vec_(_vec), bit_(_start) {}
Iterator(PredicateVector &vec, int _start = 0) : vec_(vec), bit_(_start) {}
/// Pre-increment
CUTLASS_HOST_DEVICE
@ -261,6 +299,13 @@ struct PredicateVector {
return *this;
}
/// Increment
CUTLASS_HOST_DEVICE
Iterator &operator+=(int offset) {
bit_ += offset;
return *this;
}
/// Pre-decrement
CUTLASS_HOST_DEVICE
Iterator &operator--() {
@ -268,6 +313,13 @@ struct PredicateVector {
return *this;
}
/// Decrement
CUTLASS_HOST_DEVICE
Iterator &operator-=(int offset) {
bit_ -= offset;
return *this;
}
/// Post-increment
CUTLASS_HOST_DEVICE
Iterator operator++(int) {
@ -284,6 +336,22 @@ struct PredicateVector {
return ret;
}
/// Iterator advances by some amount
CUTLASS_HOST_DEVICE
Iterator operator+(int offset) {
Iterator ret(*this);
ret.bit_ += offset;
return ret;
}
/// Iterator recedes by some amount
CUTLASS_HOST_DEVICE
Iterator operator-(int offset) {
Iterator ret(*this);
ret.bit_ -= offset;
return ret;
}
/// Returns true if iterators point to the same bit
CUTLASS_HOST_DEVICE
bool operator==(Iterator const &it) const { return bit_ == it.bit_; }
@ -294,11 +362,15 @@ struct PredicateVector {
/// Gets the bit at the pointed to location
CUTLASS_HOST_DEVICE
bool get() { return vec_[bit_]; }
bool get() { return vec_.at(bit_); }
/// Gets the bit at the pointed to location
CUTLASS_HOST_DEVICE
bool at() const { return vec_.at(bit_); }
/// Dereferences iterator
CUTLASS_HOST_DEVICE
bool operator*() const { return vec_[bit_]; }
bool operator*() const { return at(); }
/// Sets the bit at the pointed to location
CUTLASS_HOST_DEVICE
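
An assumed illustration of the new iterator arithmetic and accessors; the any_of_last_four helper and the PredicateVector<8> instantiation are not part of this change.

CUTLASS_HOST_DEVICE bool any_of_last_four(cutlass::PredicateVector<8> const &predicates) {
  cutlass::PredicateVector<8>::ConstIterator it(predicates, 0);
  it += 4;  // new operator+= advances the iterator by an offset
  bool result = false;
  for (int i = 0; i < 4; ++i, ++it) {
    result |= it.get();  // new get()/at() accessors read the referenced bit
  }
  return result;
}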

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -80,14 +80,16 @@ struct BatchedReduction {
typename Traits::ScalarA inRegs[Traits::maxInReg];
typename Traits::ScalarAccum AccumRegs[Traits::maxOutReg];
#pragma unroll
for (int subTile = 0; subTile < tileSize; subTile += subTileSize) {
int tileOffset = subTileBase + subTileOffset;
// Init AccumRegs
#pragma unroll
for (int i = 0; i < Traits::ThreadShape::kW; i++)
AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(0.0f);
// Fetch c0
typename Traits::ScalarAccum c0[Traits::ThreadShape::kW];
#pragma unroll
for (int i = 0; i< Traits::ThreadShape::kW; i++)
c0[i] = static_cast<typename Traits::ScalarAccum>(params.d_c[tileOffset + i]);
@ -131,11 +133,13 @@ struct BatchedReduction {
template<bool ThreadShapeMultiple2>
CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum, typename Traits::ScalarAccum const *old, typename Traits::ScalarAccum *output) {
if (ThreadShapeMultiple2 == true) {
#pragma unroll
for (int i = 0; i < Traits::ThreadShape::kW / 2; i++) {
functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 2>(&accum[2 * i], &old[2 * i], &output[2 * i]);
}
}
else {
#pragma unroll
for (int i = 0; i < Traits::ThreadShape::kW; i++) {
functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 1>(&accum[i], &old[i], &output[i]);
}

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,6 +72,23 @@ struct Shape {
static int const kC = kC_;
};
/**
* @brief A Shape implementing \ref layout_concept describing the dimensions of a cube.
* @concept{layout_concept}
*/
template <int kH_, int kW_>
struct Shape<1, kH_, kW_, 1> {
/// The depth of the cube.
static int const kD = 1;
/// The height of the cube.
static int const kH = kH_;
/// The width of the cube.
static int const kW = kW_;
/// The number of scalars per element.
static int const kC = 1;
};
/**
 * @brief Compute derived counts of a \ref layout_concept based class
*/
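
For illustration only (the TileC name is an assumption), a rank-2 tile written with unit depth and channel dimensions now resolves to the partial specialization added above:

typedef cutlass::Shape<1, 16, 64, 1> TileC;  // matches Shape<1, kH_, kW_, 1>
static_assert(TileC::kD == 1 && TileC::kC == 1, "depth and channel collapse to 1");
static_assert(cutlass::ShapeCount<TileC>::kCount == 16 * 64, "16x64 elements");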

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -63,12 +63,7 @@ struct IdentityTensorMapFunc {
and assumptions about vectorizing memory accesses throughout CUTLASS. It also matches various
BLAS conventions in which only the "leading dimension" or most significant stride of a rank=2
matrix is provided.
This does affect the ability of constructing arbitrary "sparse" 2-D matrices in memory where all
stride elements are > 1. This can be overcome by defining a custom mapping function and a
StorageRank of 3 or more.
Examples:
(These examples use helpers for matrix layouts defined in cutlass/matrix_traits.h)
@ -85,7 +80,7 @@ struct IdentityTensorMapFunc {
TensorRef<int8_t, 2, MatrixLayout::ColumnMajorInterleaved<32> > C;
4. Defining a sparse matrix with arbitrary strides in each dimension
4. Defining a matrix with arbitrary strides in each dimension
struct ContiguousLayout {
@ -545,6 +540,10 @@ class TensorRef<Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_> {
CUTLASS_HOST_DEVICE
Storage * data() const { return ptr_; }
/// Returns the pointer to referenced data at the given coordinate
CUTLASS_HOST_DEVICE
Storage * data(TensorCoord const& coord) const { return ptr_ + offset(coord); }
/// Returns the stride of the tensor
CUTLASS_HOST_DEVICE
StorageCoord stride() const {

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -75,7 +75,7 @@ class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Ind
LongIndex_> ConstTensorRef;
/// Base tensor reference
typedef Base TensorRef;
typedef Base TensorRef_t;
/// Storage type
typedef typename Base::Storage Storage;
@ -84,14 +84,14 @@ class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Ind
typedef typename Base::Index Index;
/// Coordinate in logical tensor space
typedef typename TensorRef::TensorCoord TensorCoord;
typedef typename TensorRef_t::TensorCoord TensorCoord;
/// Coordinate in storage n-D array
typedef typename TensorRef::StorageCoord StorageCoord;
typedef typename TensorRef_t::StorageCoord StorageCoord;
/// Stride vector in storage coordinate space
/// Least significant stride is = 1 and not stored
typedef typename TensorRef::StrideVector StrideVector;
typedef typename TensorRef_t::StrideVector StrideVector;
/// TensorView of constant value
typedef TensorView<
@ -115,11 +115,8 @@ class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Ind
/// Type used to compute the offset of an element to the base of a tensor
typedef typename Base::LongIndex Offset_t;
/// Base class
typedef TensorRef TensorRef_t;
/// TensorRef to const-valued type
typedef typename TensorRef::ConstTensorRef ConstTensorRef_t;
typedef typename TensorRef_t::ConstTensorRef ConstTensorRef_t;
private:
//
@ -195,14 +192,46 @@ class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Ind
return true;
}
/// Returns a TensorRef pointing to the first element of the tensor.
/// Determines the order of dims of the tensor (e.g., CHW versus HWC)
CUTLASS_HOST_DEVICE
TensorRef ref() const {
return TensorRef(*this);
void getStrideOrder(int order[]) const {
for (int i = 0; i < Rank_; i++) order[i] = i;
// Bubble sort
for (int start = 0; start < Rank_ - 1; start++) {
for (int i = start; i < Rank_ - 1; i++) {
if (this->stride(order[i]) < this->stride(order[i + 1])) {
int temp = order[i];
order[i] = order[i + 1];
order[i + 1] = temp;
}
}
}
// post-condition: this->stride(ord[i]) >= this->stride(ord[i+1]) for i from [0,Rank_-2]
}
/// Determines if the values in the tensor are contiguous
CUTLASS_HOST_DEVICE
bool isPacked() const {
if (Rank_ <= 0) return true;
int ord[Rank_];
getStrideOrder(ord);
// first check that the fastest-changing dimension (the smallest stride) is 1
if (this->stride(ord[Rank_ - 1]) != 1) return false;
// now check that there are no gaps between strides
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank_ - 1; i++)
if (this->stride(ord[i]) != this->stride(ord[i + 1]) * size_[ord[i + 1]]) return false;
return true;
}
/// Returns a TensorRef pointing to the first element of the tensor.
CUTLASS_HOST_DEVICE
TensorRef_t ref() const {
return TensorRef_t(*this);
}
/// Returns a TensorRef_t pointing to the first element of the tensor.
CUTLASS_HOST_DEVICE
ConstTensorRef const_ref() const {
return ConstTensorRef(*this);
}
@ -238,22 +267,22 @@ class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Ind
return result;
}
/// Returns a TensorRef offset by a given amount
/// Returns a TensorRef_t offset by a given amount
CUTLASS_HOST_DEVICE
TensorView& operator+=(TensorCoord const& b) {
this->add_pointer_offset(this->offset(b));
return *this;
}
/// Returns a TensorRef offset by a given amount
/// Returns a TensorRef_t offset by a given amount
CUTLASS_HOST_DEVICE
TensorView operator-(TensorCoord const& b) const {
TensorRef result(*this);
TensorRef_t result(*this);
result.add_pointer_offset(-this->offset(b));
return result;
}
/// Returns a TensorRef offset by a given amount
/// Returns a TensorRef_t offset by a given amount
CUTLASS_HOST_DEVICE
TensorView& operator-=(TensorCoord const& b) {
this->add_pointer_offset(-this->offset(b));
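
A small sketch (shapes and the packing_example name are assumptions) of what the new isPacked() reports; the constructor argument pattern (pointer, stride vector, bounds) follows the thread-level Transform<> code earlier in this change.

CUTLASS_HOST_DEVICE bool packing_example(int *data) {
  typedef cutlass::TensorView<int, 2, cutlass::MatrixLayout::RowMajor> View;
  // strides (16, 1) exactly tile the 8x16 bounds, so the view is packed
  View packed(data, cutlass::make_Coord(16, 1), cutlass::make_Coord(8, 16));
  // a leading stride of 32 leaves a gap after every row of 16 elements
  View strided(data, cutlass::make_Coord(32, 1), cutlass::make_Coord(8, 16));
  return packed.isPacked() && !strided.isPacked();  // expected to be true
}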

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -48,7 +48,7 @@ struct TileAllocation {
typedef Scalar_ Scalar;
/// The actual storage (may differ from the scalar type)
typedef typename StorageType<sizeof(Scalar)>::Type Storage;
typedef typename StorageType<int(sizeof(Scalar))>::Type Storage;
/// Size of the allocation in units of scalars
typedef Shape_ Shape;
@ -165,4 +165,62 @@ struct ZipTileAllocation {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Manages three tile allocations as if they are one allocation
template <typename First_, typename Second_, typename Third_>
struct ZipTileAllocationTriple {
//
// Type definitions
//
/// First tensor allocation
typedef First_ First;
/// Second tensor allocation
typedef Second_ Second;
/// meta data tensor allocation
typedef Third_ Third;
/// Defines the tensor reference for this allocation
typedef Zip3TensorRef<typename First::TensorRef,
typename Second::TensorRef,
typename Third::TensorRef> TensorRef;
/// Defines the tensor reference for this allocation
typedef Zip3TensorRef<typename First::ConstTensorRef,
typename Second::ConstTensorRef,
typename Third::ConstTensorRef>
ConstTensorRef;
//
// Data members
//
/// First tensor allocation
First first;
/// Second tensor allocation
Second second;
/// meta data tensor
Third third;
//
// Methods
//
/// Returns a TensorRef object pointing to the data
CUTLASS_DEVICE
TensorRef reference() {
return TensorRef(first.reference(), second.reference(), third.reference());
}
/// Returns a TensorRef object pointing to the data
CUTLASS_DEVICE
ConstTensorRef reference() const {
return ConstTensorRef(first.reference(), second.reference(), third.reference());
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -293,12 +293,10 @@ struct TileIteratorBase {
stride_d = _stride_d;
stride_h = _stride_h;
stride_w = _stride_w;
inc_w = stride_w * Delta::kW;
inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1);
inc_d = stride_h * Delta::kD - stride_h * Delta::kH * (Iterations::kH - 1) -
stride_w * Delta::kW * (Iterations::kW - 1);
inc_advance = 0;
if (kAdvance == IteratorAdvance::kH) {
@ -740,9 +738,13 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
template <typename Fragment, typename PredicateIterator>
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
FragmentIterator frag_iterator(fragment);
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Iterations::kC; ++c) {
if (*pred_it) {
load_element(
@ -789,8 +791,11 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
template <typename Fragment>
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
FragmentIterator frag_iterator(fragment);
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Iterations::kC; ++c) {
load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
}
@ -1076,7 +1081,6 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
ThreadOffset thread_offset_func = ThreadOffset())
: params(_params), stage(0) {
thread_offset = thread_offset_func();
params.pointer += (block_offset[0] + thread_offset[0]) * params.stride_d +
(block_offset[1] + thread_offset[1]) * params.stride_h +
(block_offset[2] + thread_offset[2]) * params.stride_w;
@ -1148,10 +1152,13 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
template <typename Fragment, typename PredicateIterator>
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) {
FragmentConstIterator frag_iterator(fragment);
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Iterations::kC; ++c) {
if (*pred_it) {
store_element(
@ -1213,9 +1220,13 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
FragmentIterator frag_iterator(fragment);
CUTLASS_PRAGMA_UNROLL
for (int d = 0; d < Iterations::kD; ++d) {
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Iterations::kC; ++c) {
if (*pred_it) {
load_element(
@ -1262,8 +1273,11 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
template <typename Fragment>
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
FragmentIterator frag_iterator(fragment);
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < Iterations::kC; ++c) {
load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
}

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -272,6 +272,9 @@ struct PredicatedTileLoadStream : public TileLoadStream<Iterator_, Transformer_>
/// Parameters object used to construct generic load stream
typedef typename Base::Params Params;
///
typedef typename Iterator::Scalar Scalar;
//
// Data members
@ -331,6 +334,9 @@ struct PredicatedTileStoreStream : public TileStoreStream<Iterator_, Transformer
/// Parameters object used to construct generic load stream
typedef typename Base::Params Params;
///
typedef typename Iterator::Scalar Scalar;
//
// Data members
//

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -131,7 +131,7 @@ struct TileTraitsContiguousMajor {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Tiling in which warps rake across the contiguous dimension
template <typename Tile_, int Threads>
template <typename Tile_, int Threads, int AccessSize = 1>
struct TileTraitsWarpRake {
/// Shape of tile
typedef Tile_ Tile;
@ -163,10 +163,10 @@ struct TileTraitsWarpRake {
typedef Shape<1, kWarpsStrided, kWarpsContiguous * kWarpSize> ThreadShape;
/// The same warp rakes along the contiguous dimension
typedef Shape<1, kWarpsStrided, kWarpSize> Delta;
typedef Shape<1, kWarpsStrided, kWarpSize * AccessSize> Delta;
/// Number of iterations
typedef Shape<1, Tile::kH / Delta::kH, Tile::kW / ThreadShape::kW> Iterations;
typedef Shape<1, Tile::kH / Delta::kH, (Tile::kW / AccessSize) / ThreadShape::kW> Iterations;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
@ -182,7 +182,7 @@ struct TileTraitsWarpRake {
int warp_w = (warp % kWarpsContiguous);
int warp_h = (warp / kWarpsContiguous);
return make_Coord(0, warp_h, lane + kWarpSpanContiguous * warp_w, 0);
return make_Coord(0, warp_h, AccessSize * (lane + kWarpSpanContiguous * warp_w), 0);
}
};
};

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -84,7 +84,6 @@
* - \p aligned_storage
*
* (4) Functions and types that are STL-like (but aren't in the STL):
* - \p TODO: min and max functors?
*
* The idea is that, as we drop support for older compilers, we can simply #define
* the \p __NV_STD_XYZ macros and \p platform namespace to alias their C++

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -259,6 +259,40 @@ union Vector<uint4_t, kLanes_> {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Vector definition for 8-bit signed integer datatype
template <int kLanes_>
union Vector<int8_t, kLanes_> {
/// The scalar type.
typedef int8_t Scalar;
/// The number of elements in the vector.
enum { kLanes = kLanes_ };
/// The size of the vector.
enum { kVectorSize = kLanes };
/// The number of registers needed to store the vector.
enum { kRegisters = kVectorSize < 4 ? 1 : (kVectorSize+3) / 4 };
// static_assert((kLanes >= 2) && !(kLanes % 2),
// "May only construct vectors of int8_t that are multiples of 8 bits.");
/// The aligned storage to make sure we have good alignment.
AlignedStruct<kVectorSize> aligned_;
/// The data in registers.
uint32_t registers[kRegisters];
/// Default Constructor
CUTLASS_HOST_DEVICE
Vector() {}
/// Constructor to convert from uint32_t type
CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; }
/// Accessor to the ith lane.
CUTLASS_HOST_DEVICE int operator[](uint32_t i) const {
return (registers[i / 4] >> (i % 4 * 8) & 0xff);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_>
CUTLASS_HOST_DEVICE void make_zero(Scalar_& x) {
x = Scalar_(0);
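
An assumed illustration of the new int8_t specialization above, using the uint32_t constructor and the lane accessor it defines:

CUTLASS_HOST_DEVICE int third_lane_example() {
  // four signed 8-bit lanes packed little-endian into a single 32-bit register
  cutlass::Vector<int8_t, 4> v(0x04030201u);  // lanes 0..3 hold 1, 2, 3, 4
  return v[2];                                // extracts byte 2 -> 3
}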

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -30,6 +30,10 @@
#if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700)
#define CUTLASS_USE_WMMA_API
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 720)
#define CUTLASS_USE_INT_WMMA
#endif
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
#define CUTLASS_USE_SUBBYTE_WMMA
#endif
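
A brief, assumed illustration of keying downstream code off the new macros, which simply mirror the guards above:

#if defined(CUTLASS_USE_INT_WMMA)
// integer WMMA code paths compile only when the sm_72 / CUDA 10 guard above holds
#endif

#if defined(CUTLASS_USE_SUBBYTE_WMMA)
// sub-byte WMMA code paths compile only when the sm_75 / CUDA 10 guard above holds
#endif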

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -72,6 +72,57 @@ ZipTensorRef<First, Second> make_ZipTensorRef(First const &first, Second const &
return ZipTensorRef<First, Second>(first, second);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Any simple way to do so?
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename First_, typename Second_, typename Third_>
struct Zip3TensorRef {
/// First tensor ref
typedef First_ First;
/// Second tensor ref
typedef Second_ Second;
/// Third tensor ref
typedef Third_ Third;
//
// Data members
//
/// First TensorRef
First first;
/// Second TensorRef
Second second;
/// Third TensorRef
Third third;
//
// Methods
//
CUTLASS_HOST_DEVICE
Zip3TensorRef() {}
CUTLASS_HOST_DEVICE
Zip3TensorRef(First const& _first, Second const& _second, Third const& _third) :
first(_first), second(_second), third(_third) {}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructs a ZipTensorRef
template <typename First, typename Second, typename Third>
CUTLASS_HOST_DEVICE
Zip3TensorRef<First, Second, Third> make_Zip3TensorRef(First const &first,
Second const &second,
Third const &third) {
return Zip3TensorRef<First, Second, Third>(first, second, third);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
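
A hypothetical helper (the zip_with_metadata name is an assumption) showing how the new three-way zip is built from any three reference types:

template <typename Data, typename Scale, typename Meta>
CUTLASS_HOST_DEVICE cutlass::Zip3TensorRef<Data, Scale, Meta> zip_with_metadata(
    Data const &data, Scale const &scale, Meta const &meta) {
  // e.g. bundle a data ref, a scale ref, and a metadata ref into one handle
  return cutlass::make_Zip3TensorRef(data, scale, meta);
}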

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -47,6 +47,9 @@ class ZipTileIterator {
/// Second iterator type
typedef Second_ Second;
///
typedef typename First::Scalar Scalar;
/// Params object
struct Params {

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

Some files were not shown because too many files have changed in this diff.