CUTLASS 2.2 (#96)
Adds support for NVIDIA Ampere Architecture features. CUDA 11 Toolkit recommended.
This commit is contained in:
parent
e33d90b361
commit
86931fef85
16
CHANGELOG.md
16
CHANGELOG.md
|
@ -2,6 +2,22 @@
|
|||
|
||||
# CUTLASS 2.x
|
||||
|
||||
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
|
||||
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
* Fast Tensor Core operations:
|
||||
* Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
|
||||
* Tensor Float 32, BFloat16, and double-precision data types
|
||||
* Mixed integer data types (int8, int4, bin1)
|
||||
* Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
|
||||
* Features:
|
||||
* SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM
|
||||
* Complex-valued GEMMs targeting NVIDIA Ampere Tensor Cores in double-precision and Tensor Float 32
|
||||
* Gaussian complex GEMMs using 3m complex multiply algorithm
|
||||
* Universal GEMM kernel supporting two batch modes and two algorithms for parallel reductions
|
||||
* Policy updates:
|
||||
* [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) needed to enable NVIDIA Ampere Architecture features
|
||||
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`
|
||||
|
||||
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
|
||||
* BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
|
||||
* API to launch compiled kernel instances for GEMM and planar complex GEMM
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
@ -32,7 +32,7 @@ endif()
|
|||
|
||||
message(STATUS "CMake Version: ${CMAKE_VERSION}")
|
||||
|
||||
project(CUTLASS VERSION 2.1.0 LANGUAGES CXX)
|
||||
project(CUTLASS VERSION 2.2.0 LANGUAGES CXX)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
|
||||
|
||||
find_package(Doxygen QUIET)
|
||||
|
@ -84,7 +84,7 @@ endif()
|
|||
|
||||
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
|
||||
if (NOT CUDA_VERSION VERSION_LESS 7.5)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 50)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 8.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61)
|
||||
|
@ -98,6 +98,9 @@ endif()
|
|||
if (NOT CUDA_VERSION VERSION_LESS 10.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 11.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
|
||||
endif()
|
||||
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
|
||||
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
|
||||
|
||||
|
@ -154,7 +157,7 @@ endif()
|
|||
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
|
||||
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
set(CUTLASS_ENABLE_F16C ON CACHE BOOL "Enable F16C x86 extensions in host code.")
|
||||
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
|
||||
|
||||
#
|
||||
# CUTLASS generator cmake configuration
|
||||
|
@ -248,8 +251,8 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
|
|||
endif()
|
||||
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-pragma-unroll-threshold=100000)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-unroll-threshold=5000)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -pragma-unroll-threshold=100000)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)
|
||||
|
||||
string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION})
|
||||
|
@ -271,7 +274,7 @@ function(cutlass_apply_cuda_gencode_flags TARGET)
|
|||
set(NVCC_FLAGS)
|
||||
set(CLANG_FLAGS)
|
||||
foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
|
||||
list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
|
||||
set(CODES)
|
||||
if(CUTLASS_NVCC_EMBED_CUBIN)
|
||||
list(APPEND CODES sm_${ARCH})
|
||||
|
|
|
@ -9,15 +9,17 @@ This is the official list of CUTLASS developers and contributors.
|
|||
## DEVELOPERS
|
||||
Andrew Kerr
|
||||
Haicheng Wu
|
||||
Naila Farooqui
|
||||
Manish Gupta
|
||||
Dustyn Blasig
|
||||
Pradeep Ramani
|
||||
Manish Gupta
|
||||
Aditya Atluri
|
||||
Naila Farooqui
|
||||
Piotr Majcher
|
||||
Paul Springer
|
||||
David Tanner
|
||||
Scott Yokim
|
||||
Jin Wang
|
||||
Scott Yokim
|
||||
Markus Hohnerbach
|
||||
Aditya Atluri
|
||||
David Tanner
|
||||
|
||||
## CONTRIBUTORS
|
||||
Timothy Costa
|
||||
|
@ -25,12 +27,10 @@ Julien Demouth
|
|||
Brian Fahs
|
||||
Michael Goldfarb
|
||||
Mostafa Hagog
|
||||
Markus Hohnerbach
|
||||
Fei Hu
|
||||
Alan Kaatz
|
||||
Tina Li
|
||||
Timmy Liu
|
||||
Piotr Majcher
|
||||
Duane Merrill
|
||||
Kevin Siu
|
||||
Markus Tavenrath
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
@ -206,14 +206,14 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
|
|||
function(cutlass_correct_source_file_language_property)
|
||||
if(CUDA_COMPILER MATCHES "clang")
|
||||
foreach(File ${ARGN})
|
||||
if(${File} MATCHES ".*\.cu$")
|
||||
if(File MATCHES ".*\.cu$")
|
||||
set_source_files_properties(${File} PROPERTIES LANGUAGE CXX)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED ON CACHE BOOL "Enable combined source compilation")
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation")
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
|
||||
|
||||
function(cutlass_unify_source_files TARGET_ARGS_VAR)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
Copyright (c) 2017 - 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
|
98
README.md
98
README.md
|
@ -1,8 +1,8 @@
|
|||
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
|
||||
|
||||
# CUTLASS 2.1
|
||||
# CUTLASS 2.2
|
||||
|
||||
_CUTLASS 2.1 - April 2020_
|
||||
_CUTLASS 2.2 - June 2020_
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
|
||||
|
@ -17,14 +17,28 @@ and applications.
|
|||
To support a wide variety of applications, CUTLASS provides extensive support for
|
||||
mixed-precision computations, providing specialized data-movement and
|
||||
multiply-accumulate abstractions for half-precision floating
|
||||
point (FP16), single-precision floating point (FP32), double-precision floating
|
||||
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
|
||||
single-precision floating point (FP32), double-precision floating
|
||||
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
|
||||
Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations for
|
||||
|
||||
Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations
|
||||
targeting the programmable, high-throughput _Tensor Cores_ implemented by
|
||||
NVIDIA's Volta and Turing architectures.
|
||||
NVIDIA's Volta, Turing, and Ampere architectures.
|
||||
|
||||
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
|
||||
|
||||
See the [functionality listing](media/docs/functionality.md) for the list of operations
|
||||
supported at each level of the execution model hierarchy.
|
||||
|
||||
# What's New in CUTLASS 2.2
|
||||
|
||||
CUTLASS 2.2 is a significant update to CUTLASS adding:
|
||||
|
||||
- Coverage of [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
- Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types
|
||||
- Deep software pipelines using asynchronous copy
|
||||
- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
|
||||
# What's New in CUTLASS 2.1
|
||||
|
||||
CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
|
||||
|
@ -32,7 +46,6 @@ CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
|
|||
- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
|
||||
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
|
||||
|
||||
|
||||
# What's New in CUTLASS 2.0
|
||||
|
||||
CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:
|
||||
|
@ -43,9 +56,6 @@ CUTLASS 2.0 is a substantial refactoring from the previous version, intended to
|
|||
|
||||
**See the [CHANGELOG](CHANGELOG.md) for more details.**
|
||||
|
||||
See the [functionality listing](media/docs/functionality.md) for the list of operations
|
||||
supported at each level of the execution model hierarchy.
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=/media/images/cutlass-performance-plot.png></p>
|
||||
|
@ -53,15 +63,15 @@ supported at each level of the execution model hierarchy.
|
|||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows CUTLASS performance relative to cuBLAS
|
||||
for large matrix dimensions on an NVIDIA GeForce 2080 Ti and an NVIDIA TitanV
|
||||
using CUDA 10.2. Tensor Core operations are implemented using CUDA's
|
||||
for large matrix dimensions on an NVIDIA GeForce 2080 Ti, an NVIDIA A100, and an NVIDIA TitanV
|
||||
using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
|
||||
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
|
||||
|
||||
# Compatibility
|
||||
|
||||
CUTLASS requires a C++11 host compiler and
|
||||
performs best when built with the [CUDA 10.2 Toolkit](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is compatible with CUDA 9.2, CUDA 10.0, and CUDA 10.1.
|
||||
performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2.
|
||||
|
||||
We have tested the following environments.
|
||||
|
||||
|
@ -70,27 +80,28 @@ We have tested the following environments.
|
|||
| Windows 10 | Microsoft Visual Studio 2015|
|
||||
| | Microsoft Visual Studio 2017|
|
||||
| Ubuntu 16.04 | GCC 5.4.0 |
|
||||
| Ubuntu 18.04 | GCC 7.3.0 |
|
||||
| Ubuntu 18.04 | GCC 7.5.0 |
|
||||
|
||||
Additionally, CUTLASS may be built with clang.
|
||||
See [these instructions](media/docs/quickstart.md#clang) for more details.
|
||||
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
|
||||
any Maxwell-, Pascal-, Volta-, or Turing- architecture NVIDIA GPU.
|
||||
any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
|
||||
|
||||
|**GPU**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|
||||
|---|---|---|
|
||||
|NVIDIA GeForce 1080|9.2| |
|
||||
|NVIDIA TitanXP|9.2| |
|
||||
|NVIDIA Tesla P100|9.2| |
|
||||
|NVIDIA Tesla V100|9.2|10.1|
|
||||
|NVIDIA TitanV|9.2|10.1|
|
||||
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|10.0|10.2|
|
||||
|NVIDIA Tesla T4|10.0|10.2|
|
||||
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|
||||
|---|---|---|---|
|
||||
|NVIDIA Tesla P100|6.0|9.2| |
|
||||
|NVIDIA GeForce 1080|6.1|9.2| |
|
||||
|NVIDIA TitanXP|6.1|9.2| |
|
||||
|NVIDIA Tesla V100|7.0|9.2|10.1|
|
||||
|NVIDIA TitanV|7.0|9.2|10.1|
|
||||
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|
||||
|NVIDIA Tesla T4|7.5|10.0|10.2|
|
||||
|NVIDIA A100|8.0|11.0|11.0|
|
||||
|
||||
# Documentation
|
||||
|
||||
CUTLASS 2.1 is described in the following documents and the accompanying
|
||||
CUTLASS 2.2 is described in the following documents and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
|
||||
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
|
||||
|
@ -124,7 +135,7 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
|
|||
```
|
||||
|
||||
Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
|
||||
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0 and 7.5. To reduce compile time you can specify
|
||||
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify
|
||||
the architectures to build CUTLASS for by changing the CMake configuration setting
|
||||
`CUTLASS_NVCC_ARCHS`.
|
||||
|
||||
|
@ -210,6 +221,10 @@ examples/
|
|||
10_planar_complex/ # example demonstrating planar complex GEMM kernels
|
||||
|
||||
11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
|
||||
|
||||
12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu
|
||||
|
||||
13_fused_two_gemms/ # example demonstrating two GEMms fused in one kernel
|
||||
```
|
||||
|
||||
### Tools
|
||||
|
@ -255,29 +270,32 @@ $ make cutlass_profiler -j
|
|||
|
||||
Example command line for profiling SGEMM kernels is as follows:
|
||||
```
|
||||
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=4352 --n=4096 --k=4096
|
||||
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
|
||||
|
||||
=============================
|
||||
Problem ID: 1
|
||||
|
||||
Provider: CUTLASS
|
||||
Operation: cutlass_simt_sgemm_128x128_nn
|
||||
Provider: CUTLASS
|
||||
OperationKind: gemm
|
||||
Operation: cutlass_simt_sgemm_128x128_8x2_nn_align1
|
||||
|
||||
Disposition: Passed
|
||||
Status: Success
|
||||
Status: Success
|
||||
Verification: ON
|
||||
Disposition: Passed
|
||||
|
||||
Arguments: --m=4352 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 \
|
||||
--split_k_slices=1 --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 \
|
||||
--stages=2 --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 \
|
||||
--max_cc=1024
|
||||
cuBLAS: Passed
|
||||
|
||||
Bytes: 52428800 bytes
|
||||
FLOPs: 146064539648 flops
|
||||
Arguments: --m=3456 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \
|
||||
--batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \
|
||||
--warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024
|
||||
|
||||
Runtime: 10.5424 ms
|
||||
Memory: 4.63158 GiB/s
|
||||
Bytes: 180355072 bytes
|
||||
FLOPs: 115992428544 flops
|
||||
|
||||
Math: 13854.9 GFLOP/s
|
||||
Runtime: 6.73655 ms
|
||||
Memory: 24.934 GiB/s
|
||||
|
||||
Math: 17218.4 GFLOP/s
|
||||
```
|
||||
|
||||
[Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
68
cuBLAS.cmake
68
cuBLAS.cmake
|
@ -10,28 +10,35 @@ if((DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) OR
|
|||
message(STATUS "cuBLAS Disabled.")
|
||||
|
||||
elseif(NOT TARGET cublas)
|
||||
|
||||
|
||||
find_path(
|
||||
_CUBLAS_INCLUDE_DIR cublas.h
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/include
|
||||
$ENV{CUBLAS_PATH}/include
|
||||
$ENV{CUDA_PATH}/include
|
||||
${CUBLAS_PATH}/include
|
||||
/usr/include)
|
||||
_CUBLAS_INCLUDE_DIR
|
||||
NAMES cublas.h
|
||||
HINTS
|
||||
${CUBLAS_INCLUDE_PATH}
|
||||
ENV CUBLAS_INCLUDE_PATH
|
||||
${CUBLAS_PATH}
|
||||
ENV CUBLAS_PATH
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
include
|
||||
)
|
||||
|
||||
find_library(
|
||||
_CUBLAS_LIBRARY cublas
|
||||
_CUBLAS_LIBRARY
|
||||
NAMES cublas
|
||||
HINTS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
|
||||
$ENV{CUBLAS_PATH}/lib64
|
||||
$ENV{CUBLAS_PATH}/lib/x64
|
||||
$ENV{CUDA_PATH}/lib64
|
||||
$ENV{CUDA_PATH}/lib/x64
|
||||
${CUBLAS_PATH}/lib64
|
||||
${CUBLAS_PATH}/lib/x64
|
||||
/usr/lib/x86_64-linux-gnu)
|
||||
${CUBLAS_LIBRARY_PATH}
|
||||
ENV CUBLAS_LIBRARY_PATH
|
||||
${_CUBLAS_INCLUDE_DIR}/..
|
||||
${CUBLAS_PATH}
|
||||
ENV CUBLAS_PATH
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib64
|
||||
lib/x64
|
||||
lib
|
||||
)
|
||||
|
||||
if(_CUBLAS_INCLUDE_DIR AND _CUBLAS_LIBRARY)
|
||||
|
||||
|
@ -79,17 +86,20 @@ if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
|
|||
$<BUILD_INTERFACE:${CUBLAS_INCLUDE_DIR}>)
|
||||
|
||||
find_library(
|
||||
_CUBLASLT_LIBRARY cublasLt
|
||||
_CUBLASLT_LIBRARY
|
||||
NAMES cublasLt
|
||||
HINTS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
|
||||
$ENV{CUBLAS_PATH}/lib64
|
||||
$ENV{CUBLAS_PATH}/lib/x64
|
||||
$ENV{CUDA_PATH}/lib64
|
||||
$ENV{CUDA_PATH}/lib/x64
|
||||
${CUBLAS_PATH}/lib64
|
||||
${CUBLAS_PATH}/lib/x64
|
||||
/usr/lib/x86_64-linux-gnu)
|
||||
${CUBLAS_LIBRARY_PATH}
|
||||
ENV CUBLAS_LIBRARY_PATH
|
||||
${_CUBLAS_INCLUDE_DIR}/..
|
||||
${CUBLAS_PATH}
|
||||
ENV CUBLAS_PATH
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib64
|
||||
lib/x64
|
||||
lib
|
||||
)
|
||||
|
||||
if(_CUBLASLT_LIBRARY AND NOT TARGET cublasLt)
|
||||
|
||||
|
@ -106,6 +116,8 @@ if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
|
|||
|
||||
add_library(nvidia::cublasLt ALIAS cublasLt)
|
||||
|
||||
target_link_libraries(cublas INTERFACE cublasLt)
|
||||
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
*modification, are permitted provided that the following conditions are met:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -34,6 +34,8 @@
|
|||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor_op_multiplicand_sm70.h"
|
||||
#include "cutlass/layout/tensor_op_multiplicand_sm75.h"
|
||||
#include "cutlass/layout/tensor_op_multiplicand_sm80.h"
|
||||
|
||||
#include "visualize_layout.h"
|
||||
#include "register_layout.h"
|
||||
|
||||
|
@ -59,18 +61,40 @@ void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase>
|
|||
// Integer matrix multiply.int4 8832 TN kblock128
|
||||
{"TensorOpMultiplicand<4,128>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 128>>},
|
||||
// Integer matrix multiply.int4 16864 TN kblock256
|
||||
{"TensorOpMultiplicand<4,256>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 256>>},
|
||||
// Integer matrix multiply 8816 Interleaved-32
|
||||
{"TensorOpMultiplicand<8,32>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 32>>},
|
||||
// Integer matrix multiply 8816 TN kblock64
|
||||
{"TensorOpMultiplicand<8,64>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 64>>},
|
||||
{"TensorOpMultiplicand<8,128>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 128>>},
|
||||
// Matrix Multiply 1688 TN kblock32
|
||||
{"TensorOpMultiplicand<16,32>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<16, 32>>},
|
||||
// Matrix multiply 1688 NT
|
||||
{"TensorOpMultiplicand<16,64>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<16, 64>>},
|
||||
// Matrix multiply 1688.TF32 TN kblock16
|
||||
{"TensorOpMultiplicand<32,16>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<32, 16>>},
|
||||
// Matrix multiply 1688.TF32 TN kblock32
|
||||
{"TensorOpMultiplicand<32,32>",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<32, 32>>},
|
||||
// Matrix multiply 1688 NT
|
||||
{"TensorOpMultiplicandCongruous<32,32>",
|
||||
new VisualizeLayout<
|
||||
cutlass::layout::TensorOpMultiplicandCongruous<32, 32>>},
|
||||
// Matrix multiply 884 NT
|
||||
{"TensorOpMultiplicandCongruous<64,16>",
|
||||
new VisualizeLayout<
|
||||
cutlass::layout::TensorOpMultiplicandCongruous<64, 16>>},
|
||||
// Matrix multiply 884 TN
|
||||
{"TensorOpMultiplicand64bCrosswise",
|
||||
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand64bCrosswise>},
|
||||
{"TensorOpMultiplicandCongruous<128,4>",
|
||||
new VisualizeLayout<
|
||||
cutlass::layout::TensorOpMultiplicandCongruous<128, 4>>},
|
||||
|
@ -82,7 +106,7 @@ void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase>
|
|||
cutlass::layout::VoltaTensorOpMultiplicandCongruous<16>>},
|
||||
{"VoltaTensorOpMultiplicandCrosswise<16,32>",
|
||||
new VisualizeLayout<
|
||||
cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>>},
|
||||
cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>>}
|
||||
};
|
||||
|
||||
for (auto layout : layout_pairs) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -65,14 +65,26 @@ void print_usage(std::ostream &out) {
|
|||
"--extent=64,64 --vectorize=32 --output-shape=256,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicand<4,128>\" "
|
||||
"--extent=128,32 --vectorize=32 --output-shape=256,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicand<4,256>\" "
|
||||
"--extent=256,16 --vectorize=32 --output-shape=256,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicand<8,32>\" "
|
||||
"--extent=32,64 --vectorize=16 --output-shape=128,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicand<8,64>\" "
|
||||
"--extent=64,32 --vectorize=16 --output-shape=128,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicand<8,128>\" "
|
||||
"--extent=128,16 --vectorize=16 --output-shape=128,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicand<16,32>\" "
|
||||
"--extent=32,32 --vectorize=8 --output-shape=64,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicand<16,64>\" "
|
||||
"--extent=64,16 --vectorize=8 --output-shape=64,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicand<32,16>\" "
|
||||
"--extent=16,32 --vectorize=4 --output-shape=32,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicand<32,32>\" "
|
||||
"--extent=32,16 --vectorize=4 --output-shape=32,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicandCongruous<32,32>\" "
|
||||
"--extent=32,16 --vectorize=4 --output-shape=32,4\n"
|
||||
<< "$ 03_visualize_layout \"TensorOpMultiplicandCongruous<64, 16>\" "
|
||||
"--extent=16,16 --vectorize=2 --output-shape=16,4\n"
|
||||
<< "$ 03_visualize_layout \"VoltaTensorOpMultiplicandCrosswise<16,32>\" "
|
||||
"--extent=32,64 --vectorize=4 --output-shape=64,4\n"
|
||||
<< "$ 03_visualize_layout \"VotlaTensorOpMultiplicandCongruous<16>\" "
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -39,7 +39,7 @@ inner product (1/16th of output), they accumulate to single output matrix.
|
|||
|
||||
Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing
|
||||
high performance kernels at scale which works for multiple problem sizes with good abstractions is
|
||||
really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose
|
||||
really hard. CUTLASS solves this problem by providing simplified abstractions to compose
|
||||
multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU
|
||||
easily.
|
||||
|
||||
|
@ -144,7 +144,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
|
|||
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
|
||||
// This code section describes ?
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
|
@ -172,17 +172,7 @@ using Gemm = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
|
|||
ShapeMMAOp,
|
||||
EpilogueOp>;
|
||||
|
||||
int main() {
|
||||
|
||||
//
|
||||
// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
|
||||
//
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
int run() {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
|
@ -316,11 +306,30 @@ int main() {
|
|||
tensor_ref_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
|
||||
tensor_ref_d.host_view())
|
||||
? "Passed"
|
||||
: "Failed")
|
||||
<< std::endl;
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
return (passed ? 0 : -1);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
//
|
||||
// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
|
||||
//
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
|
||||
|
||||
// Returning zero, so this test passes when built with older CUDA Toolkits. Its action are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return run();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -156,7 +156,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
|
|||
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
|
||||
// This code section describes ?
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
|
@ -188,15 +188,7 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
|
|||
SwizzleThreadBlock,
|
||||
NumStages>;
|
||||
|
||||
int main() {
|
||||
|
||||
// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
int run() {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
|
@ -223,7 +215,7 @@ int main() {
|
|||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.nk()); // <- Create matrix B with dimensions N x K
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
|
@ -326,12 +318,28 @@ int main() {
|
|||
tensor_ref_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
|
||||
tensor_ref_d.host_view())
|
||||
? "Passed"
|
||||
: "Failed")
|
||||
<< std::endl;
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
return 0;
|
||||
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
return (passed ? 0 : -1);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
|
||||
|
||||
// Returning zero when built on older Toolkits so tests pass. The actions of this SDK example are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return run();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -150,12 +150,12 @@ using SmArch = cutlass::arch::Sm75;
|
|||
using ShapeMMAThreadBlock =
|
||||
cutlass::gemm::GemmShape<128, 256, 64>; // <- threadblock tile M = 128, N = 256, K = 64
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 16
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 64
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ??
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
|
||||
// This code section describes the epilogue part of the kernel
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
|
||||
|
@ -186,7 +186,7 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
|
|||
SwizzleThreadBlock,
|
||||
NumStages>;
|
||||
|
||||
int main() {
|
||||
int run() {
|
||||
|
||||
// Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 10.2.
|
||||
|
@ -222,7 +222,7 @@ int main() {
|
|||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.nk()); // <- Create matrix B with dimensions N x K
|
||||
problem_size.kn()); // <- Create matrix B with dimensions K x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
|
||||
problem_size.mn()); // <- Create matrix C with dimensions M x N
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
|
@ -325,12 +325,28 @@ int main() {
|
|||
tensor_ref_d.sync_host();
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
|
||||
tensor_ref_d.host_view())
|
||||
? "Passed"
|
||||
: "Failed")
|
||||
<< std::endl;
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
tensor_d.host_view(),
|
||||
tensor_ref_d.host_view());
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
return 0;
|
||||
std::cout << (passed ? "Passed" : "Failed") << std::endl;
|
||||
|
||||
return (passed ? 0 : -1);
|
||||
}
|
||||
|
||||
int main() {
|
||||
// Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
|
||||
// in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
|
||||
// Returning zero so this test passes when built on older Toolkits.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return run();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -500,7 +500,9 @@ int main(int argc, char const **args) {
|
|||
if (props.major < 7) {
|
||||
std::cerr << "Volta Tensor Core operations must be run on a machine with compute capability at least 70."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
|
||||
// Returning zero so this test passes on older architectures even though its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else if (props.major == 7 && props.minor <= 2) {
|
||||
//
|
||||
|
@ -508,7 +510,9 @@ int main(int argc, char const **args) {
|
|||
//
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits even though its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else if (props.major == 7 && props.minor >= 5) {
|
||||
|
@ -517,7 +521,9 @@ int main(int argc, char const **args) {
|
|||
//
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits even though its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -560,7 +560,9 @@ int main(int argc, char const **args) {
|
|||
if (props.major < 7) {
|
||||
std::cerr << "Tensor Core operations must be run on a machine with compute capability at least 70."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
|
||||
// Returning zero so this passes on older architectures. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else if (props.major == 7 && props.minor <= 2) {
|
||||
//
|
||||
|
@ -568,7 +570,9 @@ int main(int argc, char const **args) {
|
|||
//
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
|
||||
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
|
||||
// Returning zero so this passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else if (props.major == 7 && props.minor >= 5) {
|
||||
|
@ -577,7 +581,9 @@ int main(int argc, char const **args) {
|
|||
//
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
return -1;
|
||||
|
||||
// Returning zero so this passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
12_gemm_bias_relu
|
||||
gemm_bias_relu.cu
|
||||
)
|
||||
|
|
@ -0,0 +1,282 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "helper.h"
|
||||
|
||||
// The code section below describes datatype for input, output matrices and computation between
|
||||
// elements in input matrices.
|
||||
using ElementAccumulator = float; // <- data type of accumulator
|
||||
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
|
||||
using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A
|
||||
using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B
|
||||
using ElementOutput = float; // <- data type of elements in output matrix D
|
||||
|
||||
// The code section below describes matrix layout of input and output matrices. Column Major for
|
||||
// Matrix A, Row Major for Matrix B and Row Major for Matrix C
|
||||
using LayoutInputA = cutlass::layout::ColumnMajor;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
|
||||
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
||||
using MMAOp = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
// This code section describes CUDA SM architecture number
|
||||
using SmArch = cutlass::arch::Sm75;
|
||||
|
||||
// This code section describes the tile size a thread block will compute
|
||||
using ShapeMMAThreadBlock =
|
||||
cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32
|
||||
// This code section describes tile size a warp will compute
|
||||
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32
|
||||
// This code section describes the size of MMA op
|
||||
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 8, N = 8, K = 4
|
||||
|
||||
// This code section describes how threadblocks are scheduled on GPU
|
||||
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
|
||||
|
||||
// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to
|
||||
//
|
||||
// d_ij = max(0, alpha * sum_k(a_ik * b_kj) + beta * c_ij )
|
||||
//
|
||||
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput, // <- data type of output matrix
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- this is the number of elements per
|
||||
// vectorized memory access. For half
|
||||
// precision, it's 8 elements. This becomes
|
||||
// the vector width of math instructions in
|
||||
// epilogue too
|
||||
ElementAccumulator, // <- data type of accumulator
|
||||
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
|
||||
|
||||
// Number of pipelines you want to use
|
||||
constexpr int NumStages = 2;
|
||||
|
||||
using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementAccumulator,
|
||||
MMAOp,
|
||||
SmArch,
|
||||
ShapeMMAThreadBlock,
|
||||
ShapeMMAWarp,
|
||||
ShapeMMAOp,
|
||||
EpilogueOp,
|
||||
SwizzleThreadBlock,
|
||||
NumStages>;
|
||||
|
||||
int run() {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major * 10 + props.minor >= 75)) {
|
||||
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int length_m = 5120;
|
||||
const int length_n = 4096;
|
||||
const int length_k = 4096;
|
||||
|
||||
// Create a tuple of problem size for matrix multiplication
|
||||
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
|
||||
|
||||
// Initialize tensors using CUTLASS helper functions
|
||||
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
|
||||
problem_size.mk()); // <- Create matrix A with dimensions M x K
|
||||
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
|
||||
problem_size.nk()); // <- Create matrix B with dimensions N x K
|
||||
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias(
|
||||
{problem_size.m(), 1}); // <- Create matrix C with dimensions M x 1
|
||||
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// CUTLASS kernel
|
||||
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
|
||||
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
|
||||
// reference kernel
|
||||
|
||||
// Fill input and output matrices on host using CUTLASS helper functions
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_a.host_view(),
|
||||
1,
|
||||
ElementInputA(4),
|
||||
ElementInputA(-4),
|
||||
0); // <- Fill matrix A on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_b.host_view(),
|
||||
1,
|
||||
ElementInputB(4),
|
||||
ElementInputB(-4),
|
||||
0); // <- Fill matrix B on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
tensor_c_bias.host_view(),
|
||||
1,
|
||||
ElementOutput(4),
|
||||
ElementOutput(-4),
|
||||
0); // <- Fill matrix C on host with uniform-distribution random data
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_d.host_view()); // <- fill matrix D on host with zeros
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
|
||||
|
||||
// Copy data from host to GPU
|
||||
tensor_a.sync_device();
|
||||
tensor_b.sync_device();
|
||||
tensor_c_bias.sync_device();
|
||||
tensor_d.sync_device();
|
||||
tensor_ref_d.sync_device();
|
||||
|
||||
// Initialize alpha and beta for dot product computation
|
||||
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
|
||||
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
|
||||
|
||||
// Split K dimension into 1 partitions
|
||||
int split_k_slices = 1;
|
||||
|
||||
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
|
||||
// instantiated CUTLASS kernel
|
||||
typename Gemm::Arguments arguments{
|
||||
problem_size, // <- problem size of matrix multiplication
|
||||
tensor_a.device_ref(), // <- reference to matrix A on device
|
||||
tensor_b.device_ref(), // <- reference to matrix B on device
|
||||
|
||||
{tensor_c_bias.device_data(), 0}, // <- the C matrix is treated as the bias vector. We can enable the GEMM
|
||||
// to project away the N dimension by setting the stride to zero.
|
||||
|
||||
tensor_d.device_ref(), // <- reference to matrix D on device
|
||||
{alpha, beta}, // <- tuple of alpha and beta
|
||||
split_k_slices}; // <- k-dimension split factor
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Instantiate CUTLASS kernel depending on templates
|
||||
Gemm gemm_op;
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
// Launch initialized CUTLASS kernel
|
||||
status = gemm_op();
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
//
|
||||
// Create instantiation for device reference gemm kernel
|
||||
//
|
||||
|
||||
cutlass::reference::device::Gemm<ElementInputA,
|
||||
LayoutInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
ElementComputeEpilogue,
|
||||
ElementComputeEpilogue>
|
||||
gemm_device_reference;
|
||||
|
||||
// Launch device reference to compute strictly the product A * B
|
||||
gemm_device_reference(
|
||||
problem_size,
|
||||
alpha,
|
||||
tensor_a.device_ref(),
|
||||
tensor_b.device_ref(),
|
||||
0,
|
||||
tensor_c_bias.device_ref(),
|
||||
tensor_ref_d.device_ref());
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Copy output data from CUTLASS and reference kernel to host for comparison
|
||||
tensor_d.sync_host();
|
||||
tensor_ref_d.sync_host();
|
||||
|
||||
// Compute bias + relu in host code
|
||||
for (int i = 0; i < problem_size.m(); ++i) {
|
||||
for (int j = 0; j < problem_size.n(); ++j) {
|
||||
tensor_ref_d.at({i, j}) = std::max(
|
||||
ElementOutput(0),
|
||||
ElementOutput(tensor_ref_d.at({i, j}) + beta * tensor_c_bias.at({i, 0}))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
||||
std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
|
||||
tensor_ref_d.host_view())
|
||||
? "Passed"
|
||||
: "Failed")
|
||||
<< std::endl;
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main() {
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return run();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cutlass_example_add_executable(
|
||||
13_fused_two_gemms
|
||||
fused_gemm.cu
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
13_fused_two_gemms
|
||||
PRIVATE
|
||||
.
|
||||
)
|
||||
|
|
@ -0,0 +1,190 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_gemm_run.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_nonfused_gemm_f16() {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2
|
||||
>;
|
||||
using Gemm1 = cutlass::gemm::device::Gemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2
|
||||
>;
|
||||
|
||||
B2bNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
|
||||
bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
}
|
||||
|
||||
void run_fused_gemm_f16() {
|
||||
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
cutlass::half_t,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::half_t,
|
||||
cutlass::layout::ColumnMajor,
|
||||
ElementOutput,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2
|
||||
>;
|
||||
|
||||
B2bFusedGemmRun<B2bGemm> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back FP16 TN GEMMs...\n";
|
||||
bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif //#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
|
|
@ -0,0 +1,608 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_relu.h"
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#define CHECK_GT(val1, val2) \
|
||||
if((val1) <= (val2)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
||||
#define CHECK_TRUE(val) \
|
||||
if(!(val)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Gemm0_, typename Gemm1_>
|
||||
struct B2bNonFusedGemmRun
|
||||
{
|
||||
|
||||
using Gemm0 = Gemm0_;
|
||||
using Gemm1 = Gemm1_;
|
||||
using ElementAccumulator = typename Gemm0::ElementAccumulator;
|
||||
using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
uint64_t seed;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
B2bNonFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
// TODO: Implement the rest
|
||||
std::cerr << "Not implemented\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D0.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D1.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D0.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D1.host_view());
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
tensor_D0.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_D1.sync_device();
|
||||
reference_D0.sync_device();
|
||||
reference_D1.sync_device();
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
typename Gemm0::Arguments arguments_0{
|
||||
problem_size_0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_D0.device_ref(),
|
||||
{alpha0, beta0}
|
||||
};
|
||||
|
||||
typename Gemm1::Arguments arguments_1{
|
||||
problem_size_1,
|
||||
tensor_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
tensor_D1.device_ref(),
|
||||
{alpha1, beta1}
|
||||
};
|
||||
|
||||
|
||||
Gemm0 gemm_op_0;
|
||||
Gemm1 gemm_op_1;
|
||||
|
||||
cutlass::Status status = gemm_op_0.initialize(arguments_0);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
status = gemm_op_1.initialize(arguments_1);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
|
||||
cudaEvent_t start, stop1, stop2;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop1);
|
||||
cudaEventCreate(&stop2);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
for(int i = 0; i < 100; i++) {
|
||||
|
||||
status = gemm_op_1();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop2);
|
||||
cudaDeviceSynchronize();
|
||||
float gemm0Time, gemm1Time, totalTime;
|
||||
cudaEventElapsedTime(&gemm0Time, start, stop1);
|
||||
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n";
|
||||
std::cout << "total time " << totalTime / 100.0 << " ms\n";
|
||||
|
||||
tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
cutlass::reference::device::Gemm<
|
||||
typename Gemm0::ElementA, typename Gemm0::LayoutA,
|
||||
typename Gemm0::ElementB, typename Gemm0::LayoutB,
|
||||
typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename Gemm0::Operator>
|
||||
reference_gemm_0;
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename Gemm1::ElementA, typename Gemm1::LayoutA,
|
||||
typename Gemm1::ElementB, typename Gemm1::LayoutB,
|
||||
typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename Gemm1::Operator>
|
||||
reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
tensor_C0.device_ref(),
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
tensor_C1.device_ref(),
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
// Wait for kernels to finish
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
if (!passed) {
|
||||
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_B2bGemm_device_nonfused.txt";
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
<< "\nD0 =\n" << tensor_D0.host_view()
|
||||
<< "\nB1 =\n" << tensor_B1.host_view()
|
||||
<< "\nC1 =\n" << tensor_C1.host_view()
|
||||
<< "\n\nReference =\n" << reference_D1.host_view()
|
||||
<< "\nComputed =\n" << tensor_D1.host_view();
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename B2bGemm_>
|
||||
struct B2bFusedGemmRun
|
||||
{
|
||||
|
||||
using B2bGemm = B2bGemm_;
|
||||
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
|
||||
using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute;
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
uint64_t seed;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
B2bFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
// TODO: Implement the rest
|
||||
std::cerr << "Not implemented\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
// cutlass::HostTensor<
|
||||
// typename B2bGemm::ElementC,
|
||||
// typename B2bGemm::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D1.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D0.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D1.host_view());
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_D1.sync_device();
|
||||
reference_D0.sync_device();
|
||||
reference_D1.sync_device();
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
typename B2bGemm::Arguments arguments{
|
||||
problem_size_0,
|
||||
problem_size_1,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
tensor_D1.device_ref(),
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
};
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
|
||||
cutlass::Status status = b2b_gemm_op.initialize(arguments);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
status = b2b_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop);
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "time " << gemmTime / 100.0 << " ms\n";
|
||||
|
||||
//tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename B2bGemm::Operator>
|
||||
reference_gemm_0, reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
tensor_C0.device_ref(),
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
tensor_C1.device_ref(),
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
if (!passed) {
|
||||
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_B2bGemm_device_fused.txt";
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
// << "\nD0 =\n" << tensor_D0.host_view()
|
||||
<< "\nB1 =\n" << tensor_B1.host_view()
|
||||
<< "\nC1 =\n" << tensor_C1.host_view()
|
||||
<< "\n\nReference =\n" << reference_D1.host_view()
|
||||
<< "\nComputed =\n" << tensor_D1.host_view();
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
|
@ -0,0 +1,190 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/gemm.h"
|
||||
|
||||
#include "device/b2b_gemm.h"
|
||||
#include "b2b_interleaved_gemm_run.h"
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_nonfused_gemm_s8() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 32, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using Gemm0 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
WarpShape0,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2
|
||||
>;
|
||||
using Gemm1 = cutlass::gemm::device::Gemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape1,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2
|
||||
>;
|
||||
|
||||
B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;
|
||||
|
||||
std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
|
||||
bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(pass)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
}
|
||||
|
||||
void run_fused_gemm_s8() {
|
||||
|
||||
using ElementOutput = int8_t;
|
||||
using ElementAccumulator = int32_t;
|
||||
using ElementCompute = float;
|
||||
|
||||
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
|
||||
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
|
||||
ElementCompute alpha0 = ElementCompute(2);
|
||||
ElementCompute beta0 = ElementCompute(0);
|
||||
ElementCompute alpha1 = ElementCompute(2);
|
||||
ElementCompute beta1 = ElementCompute(1);
|
||||
|
||||
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
|
||||
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
using EpilogueOutputOp0 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
InstructionShape::kM * InstructionShape::kN / 32,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
using EpilogueOutputOp1 =
|
||||
cutlass::epilogue::thread::LinearCombinationRelu<
|
||||
ElementOutput,
|
||||
64 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
|
||||
|
||||
using B2bGemm = cutlass::gemm::device::B2bGemm<
|
||||
int8_t,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
int8_t,
|
||||
cutlass::layout::RowMajorInterleaved<32>,
|
||||
ElementOutput,
|
||||
cutlass::layout::ColumnMajorInterleaved<32>,
|
||||
ElementAccumulator,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
|
||||
2
|
||||
>;
|
||||
|
||||
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
|
||||
|
||||
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n";
|
||||
bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
|
||||
if(passed)
|
||||
std::cout << "Pass\n";
|
||||
else
|
||||
std::cout << "Fail\n";
|
||||
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
|
|
@ -0,0 +1,633 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/tensor_copy.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/host_reorder.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "helper.h"
|
||||
|
||||
#define CHECK_GT(val1, val2) \
|
||||
if((val1) <= (val2)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
|
||||
#define CHECK_TRUE(val) \
|
||||
if(!(val)) \
|
||||
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
|
||||
|
||||
template <typename Gemm0_, typename Gemm1_, int InterleavedK_>
|
||||
struct B2bInterleavedNonFusedGemmRun
|
||||
{
|
||||
|
||||
using Gemm0 = Gemm0_;
|
||||
using Gemm1 = Gemm1_;
|
||||
using ElementAccumulator = typename Gemm0::ElementAccumulator;
|
||||
using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
uint64_t seed;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
B2bInterleavedNonFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
// TODO: Implement the rest
|
||||
std::cerr << "Not implemented\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementA,
|
||||
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementB,
|
||||
typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm0::ElementC,
|
||||
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementB,
|
||||
typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename Gemm1::ElementC,
|
||||
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
||||
|
||||
//Reorder B0 and B1
|
||||
cutlass::reorder_column<InterleavedK_>(
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
|
||||
cutlass::reorder_column<InterleavedK_>(
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D0.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D1.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D0.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D1.host_view());
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_B0_reordered.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
tensor_D0.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_B1_reordered.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_D1.sync_device();
|
||||
reference_D0.sync_device();
|
||||
reference_D1.sync_device();
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
typename Gemm0::Arguments arguments_0{
|
||||
problem_size_0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0_reordered.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_D0.device_ref(),
|
||||
{alpha0, beta0}
|
||||
};
|
||||
|
||||
typename Gemm1::Arguments arguments_1{
|
||||
problem_size_1,
|
||||
tensor_D0.device_ref(),
|
||||
tensor_B1_reordered.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
tensor_D1.device_ref(),
|
||||
{alpha1, beta1}
|
||||
};
|
||||
|
||||
|
||||
Gemm0 gemm_op_0;
|
||||
Gemm1 gemm_op_1;
|
||||
|
||||
cutlass::Status status = gemm_op_0.initialize(arguments_0);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
status = gemm_op_1.initialize(arguments_1);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
cudaEvent_t start, stop1, stop2;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop1);
|
||||
cudaEventCreate(&stop2);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
status = gemm_op_0();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
cudaEventRecord(stop1);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
status = gemm_op_1();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop2);
|
||||
cudaDeviceSynchronize();
|
||||
float gemm0Time, gemm1Time, totalTime;
|
||||
cudaEventElapsedTime(&gemm0Time, start, stop1);
|
||||
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
|
||||
cudaEventElapsedTime(&totalTime, start, stop2);
|
||||
std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n";
|
||||
std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n";
|
||||
std::cout << "total time " << totalTime / 100.0 << " ms\n";
|
||||
|
||||
tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
cutlass::reference::device::Gemm<
|
||||
typename Gemm0::ElementA, typename Gemm0::LayoutA,
|
||||
typename Gemm0::ElementB, typename Gemm0::LayoutB,
|
||||
typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename Gemm0::Operator>
|
||||
reference_gemm_0;
|
||||
|
||||
cutlass::reference::device::Gemm<
|
||||
typename Gemm1::ElementA, typename Gemm1::LayoutA,
|
||||
typename Gemm1::ElementB, typename Gemm1::LayoutB,
|
||||
typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename Gemm1::Operator>
|
||||
reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
tensor_C0.device_ref(),
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
tensor_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
tensor_C1.device_ref(),
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
if (!passed) {
|
||||
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_B2bGemm_device_interleaved_nonfused.txt";
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
<< "\nD0 =\n" << tensor_D0.host_view()
|
||||
<< "\nB1 =\n" << tensor_B1.host_view()
|
||||
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
|
||||
<< "\nC1 =\n" << tensor_C1.host_view()
|
||||
<< "\n\nReference =\n" << reference_D1.host_view()
|
||||
<< "\nComputed =\n" << tensor_D1.host_view();
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename B2bGemm_, int InterleavedK_>
|
||||
struct B2bInterleavedFusedGemmRun
|
||||
{
|
||||
|
||||
using B2bGemm = B2bGemm_;
|
||||
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
|
||||
using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute;
|
||||
|
||||
/// Initialization
|
||||
cutlass::Distribution::Kind init_A;
|
||||
cutlass::Distribution::Kind init_B;
|
||||
cutlass::Distribution::Kind init_C;
|
||||
uint64_t seed;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
B2bInterleavedFusedGemmRun(
|
||||
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
||||
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
||||
uint64_t seed_ = 2080
|
||||
):
|
||||
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
|
||||
|
||||
/// Helper to initialize a tensor view
|
||||
template <typename Element, typename Layout>
|
||||
bool initialize_tensor(
|
||||
cutlass::TensorView<Element, Layout> view,
|
||||
cutlass::Distribution::Kind dist_kind,
|
||||
uint64_t seed) {
|
||||
|
||||
if (dist_kind == cutlass::Distribution::Uniform) {
|
||||
|
||||
cutlass::reference::host::TensorFillRandomUniform(
|
||||
view, seed, 2, -2, 0);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Identity) {
|
||||
|
||||
cutlass::reference::host::TensorFillIdentity(view);
|
||||
}
|
||||
else if (dist_kind == cutlass::Distribution::Sequential) {
|
||||
|
||||
cutlass::reference::host::BlockFillSequential(
|
||||
view.data(), view.capacity());
|
||||
}
|
||||
else {
|
||||
// TODO: Implement the rest
|
||||
std::cerr << "Not implemented\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/// Executes one test
|
||||
bool run(
|
||||
cutlass::gemm::GemmCoord problem_size_0,
|
||||
cutlass::gemm::GemmCoord problem_size_1,
|
||||
ElementCompute alpha0 = ElementCompute(1),
|
||||
ElementCompute beta0 = ElementCompute(0),
|
||||
ElementCompute alpha1 = ElementCompute(1),
|
||||
ElementCompute beta1 = ElementCompute(0),
|
||||
bool relu = true) {
|
||||
|
||||
//
|
||||
// Allocate the GEMM workspace
|
||||
//
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementA,
|
||||
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
|
||||
|
||||
// cutlass::HostTensor<
|
||||
// typename B2bGemm::ElementC,
|
||||
// typename B2bGemm::LayoutC> tensor_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementB,
|
||||
typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
|
||||
|
||||
cutlass::HostTensor<
|
||||
typename B2bGemm::ElementC,
|
||||
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
|
||||
|
||||
|
||||
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
|
||||
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
|
||||
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
|
||||
|
||||
//Reorder B0
|
||||
cutlass::reorder_column<B2bGemm::InstructionShape::kK>(
|
||||
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
|
||||
cutlass::reorder_column<InterleavedK_>(
|
||||
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
|
||||
|
||||
cutlass::reference::host::TensorFill(
|
||||
tensor_D1.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D0.host_view());
|
||||
cutlass::reference::host::TensorFill(
|
||||
reference_D1.host_view());
|
||||
|
||||
tensor_A0.sync_device();
|
||||
tensor_B0.sync_device();
|
||||
tensor_B0_reordered.sync_device();
|
||||
tensor_C0.sync_device();
|
||||
//tensor_D0.sync_device();
|
||||
tensor_B1.sync_device();
|
||||
tensor_B1_reordered.sync_device();
|
||||
tensor_C1.sync_device();
|
||||
tensor_D1.sync_device();
|
||||
reference_D0.sync_device();
|
||||
reference_D1.sync_device();
|
||||
|
||||
//
|
||||
// Initialize the GEMM operator
|
||||
//
|
||||
|
||||
typename B2bGemm::Arguments arguments{
|
||||
problem_size_0,
|
||||
problem_size_1,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0_reordered.device_ref(),
|
||||
tensor_C0.device_ref(),
|
||||
tensor_B1_reordered.device_ref(),
|
||||
tensor_C1.device_ref(),
|
||||
tensor_D1.device_ref(),
|
||||
{alpha0, beta0},
|
||||
{alpha1, beta1},
|
||||
1, /*threadblock_swizzle_k_tile*/
|
||||
};
|
||||
|
||||
B2bGemm b2b_gemm_op;
|
||||
|
||||
cutlass::Status status = b2b_gemm_op.initialize(arguments);
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
|
||||
//
|
||||
// Run the GEMM
|
||||
//
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventCreate(&stop);
|
||||
|
||||
cudaEventRecord(start);
|
||||
|
||||
for(int i = 0; i < 100; i++) {
|
||||
status = b2b_gemm_op();
|
||||
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
cudaEventRecord(stop);
|
||||
cudaDeviceSynchronize();
|
||||
float gemmTime;
|
||||
cudaEventElapsedTime(&gemmTime, start, stop);
|
||||
std::cout << "time " << gemmTime / 100.0 << " ms\n";
|
||||
|
||||
//tensor_D0.sync_host();
|
||||
tensor_D1.sync_host();
|
||||
|
||||
//
|
||||
// Verify
|
||||
//
|
||||
cutlass::reference::device::Gemm<
|
||||
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
|
||||
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
|
||||
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
|
||||
ElementAccumulator, typename B2bGemm::Operator>
|
||||
reference_gemm_0, reference_gemm_1;
|
||||
|
||||
reference_gemm_0(
|
||||
problem_size_0,
|
||||
alpha0,
|
||||
tensor_A0.device_ref(),
|
||||
tensor_B0.device_ref(),
|
||||
beta0,
|
||||
tensor_C0.device_ref(),
|
||||
reference_D0.device_ref()
|
||||
);
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D0.device_view());
|
||||
}
|
||||
|
||||
reference_gemm_1(
|
||||
problem_size_1,
|
||||
alpha1,
|
||||
reference_D0.device_ref(),
|
||||
tensor_B1.device_ref(),
|
||||
beta1,
|
||||
tensor_C1.device_ref(),
|
||||
reference_D1.device_ref()
|
||||
);
|
||||
|
||||
|
||||
if(relu) {
|
||||
cutlass::reference::device::TensorReLu(reference_D1.device_view());
|
||||
}
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
reference_D0.sync_host();
|
||||
reference_D1.sync_host();
|
||||
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
|
||||
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
|
||||
|
||||
bool passed = cutlass::reference::host::TensorEquals(
|
||||
reference_D1.host_view(),
|
||||
tensor_D1.host_view());
|
||||
|
||||
CHECK_TRUE(passed);
|
||||
if (!passed) {
|
||||
|
||||
std::stringstream fname;
|
||||
|
||||
fname << "error_B2bGemm_device_interleaved_fused.txt";
|
||||
std::cerr << "Dumping results in " << fname.str() << "\n";
|
||||
|
||||
std::ofstream file(fname.str());
|
||||
|
||||
file
|
||||
<< "A0 =\n" << tensor_A0.host_view()
|
||||
<< "\nB0 =\n" << tensor_B0.host_view()
|
||||
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
|
||||
<< "\nC0 =\n" << tensor_C0.host_view()
|
||||
// << "\nD0 =\n" << tensor_D0.host_view()
|
||||
<< "\nB1 =\n" << tensor_B1.host_view()
|
||||
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
|
||||
<< "\nC1 =\n" << tensor_C1.host_view()
|
||||
<< "\n\nReference =\n" << reference_D1.host_view()
|
||||
<< "\nComputed =\n" << tensor_D1.host_view();
|
||||
}
|
||||
|
||||
return passed;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
|
@ -0,0 +1,439 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
||||
|
||||
#include "kernel/b2b_gemm.h"
|
||||
#include "kernel/default_b2b_gemm.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace device {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_ = ElementC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_ = arch::OpClassSimt,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_ = arch::Sm70,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::ThreadblockShape,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::EpilogueOutputOp,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kStages,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int AlignmentA =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentA,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int AlignmentB =
|
||||
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator_>::kAlignmentB,
|
||||
/// If true, kernel supports split-K with serial reduction
|
||||
bool SplitKSerial = false,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_ = typename DefaultGemmConfiguration<
|
||||
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||||
ElementAccumulator_>::Operator,
|
||||
/// Whether Beta is zero or not
|
||||
bool IsBetaZero = false>
|
||||
class B2bGemm {
|
||||
public:
|
||||
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = LayoutA_;
|
||||
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||||
using ElementB = ElementB_;
|
||||
using LayoutB = LayoutB_;
|
||||
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||||
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using OperatorClass = OperatorClass_;
|
||||
using ArchTag = ArchTag_;
|
||||
using ThreadblockShape0 = ThreadblockShape0_;
|
||||
using ThreadblockShape1 = ThreadblockShape1_;
|
||||
using WarpShape0 = WarpShape0_;
|
||||
using WarpShape1 = WarpShape1_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using EpilogueOutputOp0 = EpilogueOutputOp0_;
|
||||
using EpilogueOutputOp1 = EpilogueOutputOp1_;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
using Operator = Operator_;
|
||||
static int const kStages = Stages;
|
||||
static int const kAlignmentA = AlignmentA;
|
||||
static int const kAlignmentB = AlignmentB;
|
||||
static int const kAlignmentC = EpilogueOutputOp1::kCount;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
static bool const kIsBetaZero = IsBetaZero;
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
/// Define the kernel
|
||||
using B2bGemmKernel = typename kernel::DefaultB2bGemm<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
kAlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
kAlignmentB,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
ElementAccumulator,
|
||||
OperatorClass,
|
||||
ArchTag,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
kStages,
|
||||
kSplitKSerial,
|
||||
Operator,
|
||||
kIsBetaZero
|
||||
>::B2bGemmKernel;
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
GemmCoord problem_size_0;
|
||||
GemmCoord problem_size_1;
|
||||
TensorRef<ElementA const, LayoutA> ref_A0;
|
||||
TensorRef<ElementB const, LayoutB> ref_B0;
|
||||
TensorRef<ElementC const, LayoutC> ref_C0;
|
||||
TensorRef<ElementB const, LayoutB> ref_B1;
|
||||
TensorRef<ElementC const, LayoutC> ref_C1;
|
||||
TensorRef<ElementC, LayoutC> ref_D1;
|
||||
typename EpilogueOutputOp0::Params epilogue0;
|
||||
typename EpilogueOutputOp1::Params epilogue1;
|
||||
int split_k_slices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
|
||||
|
||||
}
|
||||
|
||||
/// Constructs an Arguments structure
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(
|
||||
GemmCoord problem_size_0_,
|
||||
GemmCoord problem_size_1_,
|
||||
TensorRef<ElementA const, LayoutA> ref_A0_,
|
||||
TensorRef<ElementB const, LayoutB> ref_B0_,
|
||||
TensorRef<ElementC const, LayoutC> ref_C0_,
|
||||
TensorRef<ElementB const, LayoutB> ref_B1_,
|
||||
TensorRef<ElementC const, LayoutC> ref_C1_,
|
||||
TensorRef<ElementC, LayoutC> ref_D1_,
|
||||
typename EpilogueOutputOp0::Params epilogue0_ =
|
||||
typename EpilogueOutputOp0::Params(),
|
||||
typename EpilogueOutputOp1::Params epilogue1_ =
|
||||
typename EpilogueOutputOp1::Params(),
|
||||
int split_k_slices_ = 1
|
||||
):
|
||||
problem_size_0(problem_size_0_),
|
||||
problem_size_1(problem_size_1_),
|
||||
ref_A0(ref_A0_),
|
||||
ref_B0(ref_B0_),
|
||||
ref_C0(ref_C0_),
|
||||
ref_B1(ref_B1_),
|
||||
ref_C1(ref_C1_),
|
||||
ref_D1(ref_D1_),
|
||||
epilogue0(epilogue0_),
|
||||
epilogue1(epilogue1_),
|
||||
split_k_slices(split_k_slices_) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel parameters object
|
||||
typename B2bGemmKernel::Params params_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the GEMM.
|
||||
B2bGemm() { }
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status can_implement(Arguments const &args) {
|
||||
|
||||
if (!kSplitKSerial && args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
|
||||
Status status = B2bGemmKernel::can_implement(
|
||||
args.problem_size_0,
|
||||
args.problem_size_1,
|
||||
args.ref_A0.non_const_ref(),
|
||||
args.ref_B0.non_const_ref(),
|
||||
args.ref_C0.non_const_ref(),
|
||||
args.ref_B1.non_const_ref(),
|
||||
args.ref_C1.non_const_ref(),
|
||||
args.ref_D1
|
||||
);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
|
||||
size_t bytes = 0;
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size_0,
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
|
||||
|
||||
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
|
||||
}
|
||||
|
||||
return bytes;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
|
||||
// Determine grid shape
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
||||
args.problem_size_0,
|
||||
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||||
args.split_k_slices);
|
||||
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
|
||||
// args.problem_size_1,
|
||||
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
|
||||
// args.split_k_slices);
|
||||
|
||||
if (kSplitKSerial) {
|
||||
if (args.split_k_slices > 1) {
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
size_t bytes = get_workspace_size(args);
|
||||
|
||||
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
if (args.split_k_slices > 1) {
|
||||
return Status::kErrorInvalidProblem;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the Params structure
|
||||
params_ = typename B2bGemmKernel::Params{
|
||||
args.problem_size_0,
|
||||
args.problem_size_1,
|
||||
grid_shape,
|
||||
args.ref_A0.non_const_ref(),
|
||||
args.ref_B0.non_const_ref(),
|
||||
args.ref_C0.non_const_ref(),
|
||||
args.ref_B1.non_const_ref(),
|
||||
args.ref_C1.non_const_ref(),
|
||||
args.ref_D1,
|
||||
args.epilogue0,
|
||||
args.epilogue1,
|
||||
static_cast<int *>(workspace),
|
||||
};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments
|
||||
Status update(Arguments const &args, void *workspace = nullptr) {
|
||||
|
||||
if (kSplitKSerial && args.split_k_slices > 1) {
|
||||
if (!workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
}
|
||||
|
||||
params_.ref_A0.reset(args.ref_A.non_const_ref().data());
|
||||
params_.ref_B0.reset(args.ref_B.non_const_ref().data());
|
||||
params_.ref_C0.reset(args.ref_C.non_const_ref().data());
|
||||
params_.ref_B1.reset(args.ref_B.non_const_ref().data());
|
||||
params_.ref_C1.reset(args.ref_C.non_const_ref().data());
|
||||
params_.ref_D1.reset(args.ref_D.data());
|
||||
params_.output_op_0 = args.epilogue0;
|
||||
params_.output_op_1 = args.epilogue1;
|
||||
params_.semaphore = static_cast<int *>(workspace);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status run(cudaStream_t stream = nullptr) {
|
||||
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||||
dim3 block(B2bGemmKernel::kThreadCount, 1, 1);
|
||||
|
||||
cudaError_t result;
|
||||
|
||||
int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));
|
||||
if (smem_size >= (48 << 10)) {
|
||||
result = cudaFuncSetAttribute(Kernel<B2bGemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
|
||||
result = cudaFuncSetAttribute(
|
||||
Kernel<B2bGemmKernel>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||||
|
||||
result = cudaGetLastError();
|
||||
|
||||
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(cudaStream_t stream = nullptr) {
|
||||
return run(stream);
|
||||
}
|
||||
|
||||
/// Runs the kernel using initialized state.
|
||||
Status operator()(
|
||||
Arguments const &args,
|
||||
void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr) {
|
||||
|
||||
Status status = initialize(args, workspace);
|
||||
|
||||
if (status == Status::kSuccess) {
|
||||
status = run(stream);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
|
@ -0,0 +1,74 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/**
|
||||
*/
|
||||
|
||||
#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h"
|
||||
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h"
|
||||
|
||||
int run() {
|
||||
|
||||
cudaDeviceProp props;
|
||||
|
||||
cudaError_t error = cudaGetDeviceProperties(&props, 0);
|
||||
if (error != cudaSuccess) {
|
||||
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!(props.major * 10 + props.minor >= 75)) {
|
||||
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
|
||||
<< std::endl;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
|
||||
run_nonfused_gemm_f16();
|
||||
run_fused_gemm_f16();
|
||||
run_nonfused_gemm_s8();
|
||||
run_fused_gemm_s8();
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main() {
|
||||
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
|
||||
//
|
||||
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
|
||||
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
|
||||
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
|
||||
|
||||
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return run();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,407 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
||||
>
|
||||
struct B2bGemm {
|
||||
|
||||
using B2bMma = B2bMma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using OutputOp0 = typename B2bMma::OutputOp;
|
||||
using OutputOp1 = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount0 = typename B2bMma::WarpCount0;
|
||||
static int const kThreadCount = 32 * WarpCount0::kCount;
|
||||
|
||||
/// Parameters structure
|
||||
struct Params {
|
||||
cutlass::gemm::GemmCoord problem_size_0;
|
||||
cutlass::gemm::GemmCoord problem_size_1;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
typename B2bMma::IteratorA0::Params params_A0;
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0;
|
||||
typename B2bMma::IteratorB0::Params params_B0;
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0;
|
||||
typename Epilogue::OutputTileIterator::Params params_C0;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
|
||||
typename B2bMma::IteratorB1::Params params_B1;
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1;
|
||||
typename Epilogue::OutputTileIterator::Params params_C1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
|
||||
typename Epilogue::OutputTileIterator::Params params_D1;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
|
||||
typename OutputOp0::Params output_op_0;
|
||||
typename OutputOp1::Params output_op_1;
|
||||
int *semaphore;
|
||||
int gemm_k_iterations_0;
|
||||
int gemm_k_size_0;
|
||||
int gemm_k_iterations_1;
|
||||
int gemm_k_size_1;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(): semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
|
||||
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
cutlass::gemm::GemmCoord const & problem_size_0,
|
||||
cutlass::gemm::GemmCoord const & problem_size_1,
|
||||
cutlass::gemm::GemmCoord const & grid_tiled_shape,
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0,
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0,
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
|
||||
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
|
||||
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
|
||||
int *workspace = nullptr
|
||||
):
|
||||
problem_size_0(problem_size_0),
|
||||
problem_size_1(problem_size_1),
|
||||
grid_tiled_shape(grid_tiled_shape),
|
||||
params_A0(ref_A0.layout()),
|
||||
ref_A0(ref_A0),
|
||||
params_B0(ref_B0.layout()),
|
||||
ref_B0(ref_B0),
|
||||
params_C0(ref_C0.layout()),
|
||||
ref_C0(ref_C0),
|
||||
params_B1(ref_B1.layout()),
|
||||
ref_B1(ref_B1),
|
||||
params_C1(ref_C1.layout()),
|
||||
ref_C1(ref_C1),
|
||||
params_D1(ref_D1.layout()),
|
||||
ref_D1(ref_D1),
|
||||
output_op_0(output_op_0),
|
||||
output_op_1(output_op_1) {
|
||||
|
||||
int total_gemm_k_iterations_0 = (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
|
||||
int gemm_k_iterations_0 = (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
||||
gemm_k_size_0 = gemm_k_iterations_0 * B2bMma::Shape0::kK;
|
||||
int total_gemm_k_iterations_1 = (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||
int gemm_k_iterations_1 = (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
|
||||
gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK;
|
||||
|
||||
semaphore = workspace;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage {
|
||||
typename B2bMma::B2bMmaSharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
B2bGemm() { }
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(
|
||||
cutlass::gemm::GemmCoord const & problem_size_0,
|
||||
cutlass::gemm::GemmCoord const & problem_size_1,
|
||||
typename B2bMma::IteratorA0::TensorRef ref_A0,
|
||||
typename B2bMma::IteratorB0::TensorRef ref_B0,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C0,
|
||||
typename B2bMma::IteratorB1::TensorRef ref_B1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D1) {
|
||||
|
||||
static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements;
|
||||
static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_B0, kAlignmentB)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_C0, kAlignmentC)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_B1, kAlignmentB)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_C1, kAlignmentC)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(ref_D1, kAlignmentC)) {
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if ((problem_size_0.m() % kAlignmentA) || (problem_size_0.k() % kAlignmentA) ||
|
||||
(problem_size_0.n() % kAlignmentB) || (problem_size_0.k() % kAlignmentB) ||
|
||||
(problem_size_0.m() % kAlignmentC) || (problem_size_0.n() % kAlignmentC) ||
|
||||
(problem_size_1.m() % kAlignmentA) || (problem_size_1.k() % kAlignmentA) ||
|
||||
(problem_size_1.n() % kAlignmentB) || (problem_size_1.k() % kAlignmentB) ||
|
||||
(problem_size_1.m() % kAlignmentC) || (problem_size_1.n() % kAlignmentC)) {
|
||||
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
||||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A0{
|
||||
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B0{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_0,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape0::kN
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B1{
|
||||
threadblock_tile_offset.k() * params.gemm_k_size_1,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape1::kN
|
||||
};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k_0 = min(
|
||||
params.problem_size_0.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k_1 = min(
|
||||
params.problem_size_1.k(),
|
||||
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
|
||||
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename B2bMma::IteratorA0 iterator_A0(
|
||||
params.params_A0,
|
||||
params.ref_A0.data(),
|
||||
{params.problem_size_0.m(), problem_size_k_0},
|
||||
thread_idx,
|
||||
tb_offset_A0);
|
||||
|
||||
typename B2bMma::IteratorB0 iterator_B0(
|
||||
params.params_B0,
|
||||
params.ref_B0.data(),
|
||||
{problem_size_k_0, params.problem_size_0.n()},
|
||||
thread_idx,
|
||||
tb_offset_B0);
|
||||
|
||||
typename B2bMma::IteratorB1 iterator_B1(
|
||||
params.params_B1,
|
||||
params.ref_B1.data(),
|
||||
{problem_size_k_1, params.problem_size_1.n()},
|
||||
thread_idx,
|
||||
tb_offset_B1);
|
||||
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
OutputOp0 output_op_0(params.output_op_0);
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename B2bMma::FragmentC0 src_accum;
|
||||
typename B2bMma::FragmentC1 accumulators;
|
||||
|
||||
src_accum.clear();
|
||||
accumulators.clear();
|
||||
|
||||
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, iterator_B1, src_accum, output_op_0);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
OutputOp1 output_op_1(params.output_op_1);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
|
||||
|
||||
//assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * B2bMma::Shape1::kM,
|
||||
threadblock_tile_offset.n() * B2bMma::Shape1::kN
|
||||
);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op_1.set_k_partition(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C1(
|
||||
params.params_C1,
|
||||
params.ref_C1.data(),
|
||||
params.problem_size_1.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D1(
|
||||
params.params_D1,
|
||||
params.ref_D1.data(),
|
||||
params.problem_size_1.mn(),
|
||||
thread_idx,
|
||||
threadblock_offset
|
||||
);
|
||||
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue,
|
||||
thread_idx,
|
||||
warp_idx,
|
||||
lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k()) {
|
||||
iterator_C1 = iterator_D1;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_offset.k());
|
||||
|
||||
__threadfence();
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else {
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_offset.k() + 1;
|
||||
}
|
||||
|
||||
__threadfence();
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
|
@ -0,0 +1,296 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
*modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
*notice, this list of conditions and the following disclaimer in the
|
||||
*documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its
|
||||
*contributors may be used to endorse or promote products derived from this
|
||||
*software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
|
||||
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
||||
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
|
||||
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
|
||||
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_pipelined.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
|
||||
#include "kernel/b2b_gemm.h"
|
||||
#include "threadblock/default_b2b_mma.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Beta is zero or not
|
||||
bool IsBetaZero = false
|
||||
>
|
||||
struct DefaultB2bGemm;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for Turing Architecture
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// If true, kernel is configured to support serial reduction in the epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultB2bGemm<
|
||||
ElementA, LayoutA, kAlignmentA,
|
||||
ElementB, LayoutB, kAlignmentB,
|
||||
ElementC, layout::RowMajor,
|
||||
ElementAccumulator,
|
||||
arch::OpClassTensorOp,
|
||||
arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
EpilogueOutputOp0,
|
||||
EpilogueOutputOp1,
|
||||
ThreadblockSwizzle,
|
||||
2,
|
||||
SplitKSerial,
|
||||
Operator
|
||||
> {
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA,
|
||||
LayoutA,
|
||||
kAlignmentA,
|
||||
ElementB,
|
||||
LayoutB,
|
||||
kAlignmentB,
|
||||
ElementAccumulator,
|
||||
layout::RowMajor,
|
||||
arch::OpClassTensorOp,
|
||||
arch::Sm75,
|
||||
ThreadblockShape0,
|
||||
ThreadblockShape1,
|
||||
WarpShape0,
|
||||
WarpShape1,
|
||||
InstructionShape,
|
||||
2,
|
||||
Operator,
|
||||
EpilogueOutputOp0
|
||||
>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
|
||||
ThreadblockShape1,
|
||||
typename B2bMma::Operator1,
|
||||
kPartitionsK1,
|
||||
EpilogueOutputOp1,
|
||||
EpilogueOutputOp1::kCount
|
||||
>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
|
||||
/// Partial specialization for Turing IMMA Interleaved layout
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp0,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp1,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of Interleaved k
|
||||
int InterleavedK,
|
||||
/// If true, kernel is configured to support serial reduction in the
|
||||
/// epilogue
|
||||
bool SplitKSerial,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Is Beta zero or not
|
||||
bool IsBetaZero>
|
||||
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
kAlignmentA, ElementB,
|
||||
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
|
||||
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
int32_t, arch::OpClassTensorOp, arch::Sm75,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
|
||||
ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> {
|
||||
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
|
||||
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
|
||||
|
||||
using ElementAccumulator = int32_t;
|
||||
|
||||
/// Define the threadblock-scoped matrix multiply-accumulate
|
||||
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
|
||||
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
|
||||
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
|
||||
|
||||
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
|
||||
|
||||
/// Define the epilogue for the 2nd Gemm
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
DefaultInterleavedEpilogueTensorOp<
|
||||
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
|
||||
64 / sizeof_bits<ElementC>::value, InterleavedK,
|
||||
IsBetaZero>::Epilogue;
|
||||
|
||||
/// Define the kernel-level GEMM operator.
|
||||
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
|
@ -0,0 +1,230 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy1_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class B2bMmaBase {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape0 = Shape0_;
|
||||
using Shape1 = Shape1_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy0 = Policy0_;
|
||||
using Policy1 = Policy1_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Shape describing the overall GEMM computed from shared memory
|
||||
/// by each warp.
|
||||
using WarpGemm0 = typename Policy0::Operator::Shape;
|
||||
using WarpGemm1 = typename Policy1::Operator::Shape;
|
||||
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount0 = GemmShape<Shape0::kM / WarpGemm0::kM,
|
||||
Shape0::kN / WarpGemm0::kN,
|
||||
Shape0::kK / WarpGemm0::kK>;
|
||||
using WarpCount1 = GemmShape<Shape1::kM / WarpGemm1::kM,
|
||||
Shape1::kN / WarpGemm1::kN,
|
||||
Shape1::kK / WarpGemm1::kK>;
|
||||
|
||||
/// Number of warp-level GEMM oeprations
|
||||
static int const kWarpGemmIterations0 =
|
||||
(WarpGemm0::kK / Operator0::Policy::MmaShape::kK);
|
||||
static int const kWarpGemmIterations1 =
|
||||
(WarpGemm1::kK / Operator1::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
|
||||
/// Shared storage object needed by threadblock-scoped GEMM
|
||||
template<
|
||||
typename Shape_,
|
||||
typename Policy_
|
||||
>
|
||||
class SharedStorage {
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
using Shape = Shape_;
|
||||
using Policy = Policy_;
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
|
||||
|
||||
/// Shape of the A matrix operand in shared memory
|
||||
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
|
||||
Shape::kK * kStages +
|
||||
Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB =
|
||||
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for A operand
|
||||
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
||||
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the A matrix
|
||||
CUTLASS_DEVICE
|
||||
static typename Operator::LayoutA LayoutA() {
|
||||
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutB LayoutB() {
|
||||
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefA operand_A_ref() {
|
||||
return TensorRefA{operand_A.data(), LayoutA()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the B operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B_ref() {
|
||||
return TensorRefB{operand_B.data(), LayoutB()};
|
||||
}
|
||||
};
|
||||
|
||||
using SharedStorage0 = SharedStorage<Shape0, Policy0>;
|
||||
using SharedStorage1 = SharedStorage<Shape1, Policy1>;
|
||||
union B2bMmaSharedStorage {
|
||||
SharedStorage0 sharedStorage0;
|
||||
SharedStorage1 sharedStorage1;
|
||||
};
|
||||
|
||||
|
||||
protected:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A0 operand from shared memory
|
||||
typename Operator0::IteratorA warp_tile_iterator_A0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B0 operand from shared memory
|
||||
typename Operator0::IteratorB warp_tile_iterator_B0_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B0 operand from shared memory
|
||||
typename Operator1::IteratorB warp_tile_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bMmaBase(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
B2bMmaSharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), lane_idx),
|
||||
warp_tile_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), lane_idx) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
@ -0,0 +1,509 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_base.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<int a>
|
||||
struct chk_val {
|
||||
static_assert(a==0, "check value");
|
||||
};
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape0_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA0_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA0_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB0_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB0_,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape1_,
|
||||
/// Iterates over the intermediate accumulator tile
|
||||
// (concept::MmaTensorOpFragmentIterator)
|
||||
typename FragmentIteratorA1_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB1_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB1_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
|
||||
typename OutputOp_,
|
||||
/// Policy describing tuning details (concept: MmaPipelinedPolicy)
|
||||
typename Policy0_,
|
||||
/// Policy describing tuning details (concept: MmaPipelinedPolicy)
|
||||
typename Policy1_,
|
||||
/// Transformation applied to A0 operand
|
||||
typename TransformA0_ = NumericArrayConverter<
|
||||
typename SmemIteratorA0_::Element,
|
||||
typename IteratorA0_::Element,
|
||||
IteratorA0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B0 operand
|
||||
typename TransformB0_ = NumericArrayConverter<
|
||||
typename SmemIteratorB0_::Element,
|
||||
typename IteratorB0_::Element,
|
||||
IteratorB0_::Fragment::kElements>,
|
||||
///
|
||||
/// Transformation applied to B1 operand
|
||||
typename TransformB1_ = NumericArrayConverter<
|
||||
typename SmemIteratorB1_::Element,
|
||||
typename IteratorB1_::Element,
|
||||
IteratorB1_::Fragment::kElements>,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool
|
||||
>
|
||||
class B2bMmaPipelined : public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
|
||||
public:
|
||||
|
||||
///< Base class
|
||||
using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2>;
|
||||
|
||||
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy0 = Policy0_; ///< Policy describing tuning details
|
||||
|
||||
using SmemIteratorA0 = SmemIteratorA0_;
|
||||
using SmemIteratorB0 = SmemIteratorB0_;
|
||||
|
||||
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile
|
||||
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
|
||||
using Policy1 = Policy1_; ///< Policy describing tuning details
|
||||
|
||||
using SmemIteratorB1 = SmemIteratorB1_;
|
||||
|
||||
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
|
||||
using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
|
||||
|
||||
using TransformA0 = TransformA0_;
|
||||
using TransformB0 = TransformB0_;
|
||||
using TransformB1 = TransformB1_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA0 = typename IteratorA0::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB0 = typename IteratorB0::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC0 = typename Policy0::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator0 = typename Policy0::Operator;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB1 = typename IteratorB1::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC1 = typename Policy1::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator1 = typename Policy1::Operator;
|
||||
|
||||
/// Obtain the arch tag from the warp-level operator
|
||||
using ArchTag = typename Policy0::Operator::ArchTag;
|
||||
|
||||
/// Complex transform on A0 operand
|
||||
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
|
||||
|
||||
/// Complex transform on B0 operand
|
||||
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
|
||||
|
||||
/// Complex transform on B1 operand
|
||||
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
|
||||
|
||||
// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
|
||||
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
|
||||
|
||||
private:
|
||||
|
||||
using WarpFragmentA0 = typename Operator0::FragmentA;
|
||||
using WarpFragmentB0 = typename Operator0::FragmentB;
|
||||
/// Warp Fragment of operand A1 loaded from accmulator tile
|
||||
using WarpFragmentA1 = typename FragmentIteratorA1::Fragment;
|
||||
using WarpFragmentB1 = typename Operator1::FragmentB;
|
||||
|
||||
protected:
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA0 smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
|
||||
SmemIteratorB0 smem_iterator_B0_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
|
||||
SmemIteratorB1 smem_iterator_B1_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
B2bMmaPipelined(
|
||||
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx ///< ID of each thread within a warp
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
|
||||
smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx) {
|
||||
|
||||
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
//These should stay the same across different GEMM layers
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;
|
||||
|
||||
//These may change across different GEMM layers
|
||||
int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k;
|
||||
int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n});
|
||||
this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n});
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
int gemm_k_iterations_0, ///< number of iterations of the mainloop
|
||||
FragmentC1 &accum, ///< destination accumulator tile
|
||||
IteratorA0 iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
|
||||
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
|
||||
FragmentC0 const &src_accum, ///< source accumualtor tile
|
||||
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
|
||||
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
|
||||
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
|
||||
TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
// Perform accumulation in the 'd' output operand
|
||||
FragmentC0 accum0 = src_accum;
|
||||
|
||||
FragmentA0 tb_frag_A;
|
||||
FragmentB0 tb_frag_B0;
|
||||
|
||||
tb_frag_A.clear();
|
||||
tb_frag_B0.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_A.load(tb_frag_A);
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
this->smem_iterator_A_.store(tb_frag_A);
|
||||
this->smem_iterator_B0_.store(tb_frag_B0);
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
++this->smem_iterator_B0_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA0 warp_frag_A0[2];
|
||||
WarpFragmentB0 warp_frag_B0[2];
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index(0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
Operator0 warp_mma0;
|
||||
|
||||
int smem_write_stage_idx = 1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations_0 <= 1) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
iterator_A.load(tb_frag_A);
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::WarpGemmIterations == 2.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
|
||||
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_A_.store(tb_frag_A);
|
||||
|
||||
this->smem_iterator_B0_.store(tb_frag_B0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
||||
// shared memory loads (which have the tighest latency requirement).
|
||||
iterator_A.load(tb_frag_A);
|
||||
|
||||
++this->smem_iterator_B0_;
|
||||
++this->smem_iterator_A_;
|
||||
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_A0_.add_tile_offset(
|
||||
{0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0});
|
||||
this->warp_tile_iterator_B0_.add_tile_offset(
|
||||
{-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
|
||||
|
||||
this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);
|
||||
|
||||
++this->warp_tile_iterator_A0_;
|
||||
++this->warp_tile_iterator_B0_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_B0.load(tb_frag_B0);
|
||||
|
||||
++iterator_A;
|
||||
++iterator_B0;
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations_0 <= 2) {
|
||||
iterator_A.clear_mask();
|
||||
iterator_B0.clear_mask();
|
||||
}
|
||||
}
|
||||
|
||||
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);
|
||||
}
|
||||
}
|
||||
|
||||
//2nd Gemm
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
|
||||
FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
FragmentB1 tb_frag_B1;
|
||||
|
||||
tb_frag_B1.clear();
|
||||
|
||||
// The last kblock is loaded in the prolog
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
|
||||
++iterator_B1;
|
||||
|
||||
this->smem_iterator_B1_.store(tb_frag_B1);
|
||||
|
||||
++this->smem_iterator_B1_;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Pair of fragments used to overlap shared memory loads and math instructions
|
||||
WarpFragmentA1 warp_frag_A1[2];
|
||||
WarpFragmentB1 warp_frag_B1[2];
|
||||
|
||||
//warp_tile_iterator_A1_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index(0);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
Operator1 warp_mma1;
|
||||
|
||||
smem_write_stage_idx = 1;
|
||||
|
||||
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
|
||||
|
||||
// Avoid reading out of bounds
|
||||
if (gemm_k_iterations_1 <= 1) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
|
||||
// Note: The main loop does not support Base::WarpGemmIterations == 2.
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {
|
||||
|
||||
//
|
||||
// Loop over GEMM K dimension
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {
|
||||
|
||||
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
||||
// as the case may be.
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
|
||||
|
||||
// Write fragments to shared memory
|
||||
|
||||
this->smem_iterator_B1_.store(tb_frag_B1);
|
||||
|
||||
__syncthreads();
|
||||
++smem_iterator_B1_;
|
||||
|
||||
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
||||
if (smem_write_stage_idx == 1) {
|
||||
smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
|
||||
}
|
||||
else {
|
||||
this->warp_tile_iterator_B1_.add_tile_offset(
|
||||
{-Base::kStages * Policy1::kPartitionsK *
|
||||
Base::kWarpGemmIterations1,
|
||||
0});
|
||||
}
|
||||
|
||||
smem_write_stage_idx ^= 1;
|
||||
}
|
||||
|
||||
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
|
||||
|
||||
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
|
||||
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
|
||||
|
||||
|
||||
++warp_tile_iterator_A1_;
|
||||
++this->warp_tile_iterator_B1_;
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
|
||||
iterator_B1.load(tb_frag_B1);
|
||||
++iterator_B1;
|
||||
|
||||
|
||||
// Avoid reading out of bounds if this was the last loop iteration
|
||||
if (gemm_k_iterations_1 <= 2) {
|
||||
iterator_B1.clear_mask();
|
||||
}
|
||||
}
|
||||
|
||||
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
|
@ -0,0 +1,289 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
|
||||
|
||||
#include "threadblock/b2b_mma_pipelined.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1_,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor = false>
|
||||
struct DefaultB2bMma;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Specialization for row-major output
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator, layout::RowMajor,
|
||||
OperatorClass, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1,
|
||||
WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp, false> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
|
||||
OperatorClass, 2, Operator>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
|
||||
OperatorClass, 2, Operator>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA0 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>,
|
||||
ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB0 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::ColumnMajor;
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp, true>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB,
|
||||
typename MmaCore1::Shape, FragmentIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB,
|
||||
ElementAccumulator, layout::RowMajor,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
|
||||
|
||||
};
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for column-major-interleaved output
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape0,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape1,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape0,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape1,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Number of Interleaved K
|
||||
int InterleavedK>
|
||||
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
kAlignmentB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
|
||||
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
|
||||
InstructionShape, 2, Operator, EpilogueOutputOp, true> {
|
||||
// Define the MmaCore components
|
||||
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
|
||||
true>;
|
||||
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
|
||||
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
|
||||
ElementB, LayoutB, ElementAccumulator,
|
||||
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
|
||||
true>;
|
||||
|
||||
static_assert(kAlignmentA == 128 / sizeof_bits<ElementA>::value,
|
||||
"Alignment must match thread data map's vector length");
|
||||
|
||||
static_assert(kAlignmentB ==128 / sizeof_bits<ElementB>::value,
|
||||
"Alignment must match thread data map's vector length");
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>, ElementA,
|
||||
LayoutA, 1, typename MmaCore0::IteratorThreadMapA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>, ElementB,
|
||||
LayoutB, 0, typename MmaCore0::IteratorThreadMapB>;
|
||||
|
||||
// Use fragment iterator for A operand
|
||||
using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true
|
||||
using FragmentIteratorA1 =
|
||||
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
|
||||
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
|
||||
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
|
||||
MmaCore1::Shape::kK, //kBlocksColumn
|
||||
ElementAccumulator, ElementA, AccumulatorLayout,
|
||||
InstructionShape, EpilogueOutputOp, true /*only handle beta=0 for 1st Gemm epilogue*/>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB1 =
|
||||
cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
|
||||
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
|
||||
|
||||
|
||||
|
||||
// Define the threadblock-scoped pipelined matrix multiply
|
||||
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
|
||||
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
|
||||
IteratorB0, typename MmaCore0::SmemIteratorB,
|
||||
typename MmaCore1::Shape, FragmentIteratorA1,
|
||||
IteratorB1, typename MmaCore1::SmemIteratorB,
|
||||
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
|
||||
EpilogueOutputOp,
|
||||
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
|
@ -60,6 +60,8 @@ foreach(EXAMPLE
|
|||
08_turing_tensorop_gemm
|
||||
10_planar_complex
|
||||
11_planar_complex_array
|
||||
12_gemm_bias_relu
|
||||
13_fused_two_gemms
|
||||
)
|
||||
|
||||
add_subdirectory(${EXAMPLE})
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -52,6 +52,10 @@ struct Sm72 {
|
|||
struct Sm75 {
|
||||
static int const kMinComputeCapability = 75;
|
||||
};
|
||||
struct Sm80 {
|
||||
static int const kMinComputeCapability = 80;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace arch
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Directives related to cache operations
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace arch {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Controls PTX cache operations
|
||||
struct CacheOperation {
|
||||
enum Kind {
|
||||
/// Cache at all levels - accessed again
|
||||
Always,
|
||||
/// Cache at global level
|
||||
Global,
|
||||
/// Streaming - likely to be accessed once
|
||||
Streaming,
|
||||
/// Indicates the line will not be used again
|
||||
LastUse,
|
||||
/// Don't cache, and fetch again
|
||||
Volatile,
|
||||
/// Write back at all coherent levels
|
||||
WriteBack,
|
||||
/// Write through to system memory
|
||||
WriteThrough
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -28,13 +28,271 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace arch {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Fragment type to store loaded data
|
||||
typename AccessType,
|
||||
/// The bytes of loading
|
||||
int LoadBytes
|
||||
>
|
||||
struct global_load;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Specializations
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
struct global_load<AccessType,
|
||||
32
|
||||
> {
|
||||
CUTLASS_DEVICE
|
||||
global_load(AccessType &D, void const *ptr, bool pred_guard) {
|
||||
uint4 *data = reinterpret_cast<uint4 *>(&D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %9, 0;\n"
|
||||
" mov.b32 %0, %10;\n"
|
||||
" mov.b32 %1, %11;\n"
|
||||
" mov.b32 %2, %12;\n"
|
||||
" mov.b32 %3, %13;\n"
|
||||
" mov.b32 %4, %14;\n"
|
||||
" mov.b32 %5, %15;\n"
|
||||
" mov.b32 %6, %16;\n"
|
||||
" mov.b32 %7, %17;\n"
|
||||
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
|
||||
" @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n"
|
||||
"}\n"
|
||||
: "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w),
|
||||
"=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w)
|
||||
: "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y),
|
||||
"r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y),
|
||||
"r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
struct global_load<AccessType,
|
||||
16
|
||||
> {
|
||||
CUTLASS_DEVICE
|
||||
global_load(AccessType &D, void const *ptr, bool pred_guard) {
|
||||
uint4 &data = reinterpret_cast<uint4 &>(D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %5, 0;\n"
|
||||
" mov.b32 %0, %6;\n"
|
||||
" mov.b32 %1, %7;\n"
|
||||
" mov.b32 %2, %8;\n"
|
||||
" mov.b32 %3, %9;\n"
|
||||
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
|
||||
"}\n"
|
||||
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
|
||||
: "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
struct global_load<AccessType,
|
||||
8
|
||||
> {
|
||||
CUTLASS_DEVICE
|
||||
global_load(AccessType &D, void const *ptr, bool pred_guard) {
|
||||
uint2 &data = reinterpret_cast<uint2 &>(D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %3, 0;\n"
|
||||
" mov.b32 %0, %4;\n"
|
||||
" mov.b32 %1, %5;\n"
|
||||
" @p ld.global.v2.u32 {%0, %1}, [%2];\n"
|
||||
"}\n"
|
||||
: "=r"(data.x), "=r"(data.y)
|
||||
: "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
struct global_load<AccessType,
|
||||
4
|
||||
> {
|
||||
CUTLASS_DEVICE
|
||||
global_load(AccessType &D, void const *ptr, bool pred_guard) {
|
||||
unsigned &data = reinterpret_cast<unsigned &>(D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %2, 0;\n"
|
||||
" mov.b32 %0, %3;\n"
|
||||
" @p ld.global.u32 %0, [%1];\n"
|
||||
"}\n"
|
||||
: "=r"(data)
|
||||
: "l"(ptr), "r"((int)pred_guard), "r"(data));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
struct global_load<AccessType,
|
||||
2
|
||||
> {
|
||||
CUTLASS_DEVICE
|
||||
global_load(AccessType &D, void const *ptr, bool pred_guard) {
|
||||
uint16_t &data = reinterpret_cast<uint16_t &>(D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %2, 0;\n"
|
||||
" mov.b16 %0, %3;\n"
|
||||
" @p ld.global.u16 %0, [%1];\n"
|
||||
"}\n"
|
||||
: "=h"(data)
|
||||
: "l"(ptr), "r"((int)pred_guard), "h"(data));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType
|
||||
>
|
||||
struct global_load<AccessType,
|
||||
1
|
||||
> {
|
||||
CUTLASS_DEVICE
|
||||
global_load(AccessType &D, void const *ptr, bool pred_guard) {
|
||||
if (pred_guard) D = *(reinterpret_cast<AccessType const *>(ptr));
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Fragment type to store loaded data
|
||||
typename AccessType,
|
||||
/// The bytes of loading
|
||||
int LoadBytes
|
||||
>
|
||||
struct global_store;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Specializations
|
||||
//
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename AccessType>
|
||||
struct global_store<AccessType, 32> {
|
||||
CUTLASS_DEVICE
|
||||
global_store(AccessType const &D, void *ptr, bool pred_guard) {
|
||||
uint4 const *data = reinterpret_cast<uint4 const *>(&D);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %5, 0;\n"
|
||||
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
|
||||
" @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
|
||||
"}\n"
|
||||
:
|
||||
: "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
|
||||
"r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16),
|
||||
"r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType>
|
||||
struct global_store<AccessType, 16> {
|
||||
CUTLASS_DEVICE
|
||||
global_store(AccessType const &D, void *ptr, bool pred_guard) {
|
||||
uint4 const &data = reinterpret_cast<uint4 const &>(D);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %5, 0;\n"
|
||||
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
|
||||
"}\n"
|
||||
:
|
||||
: "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w), "r"((int)pred_guard));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType>
|
||||
struct global_store<AccessType, 8> {
|
||||
CUTLASS_DEVICE
|
||||
global_store(AccessType const &D, void *ptr, bool pred_guard) {
|
||||
uint2 const &data = reinterpret_cast<uint2 const &>(D);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %3, 0;\n"
|
||||
" @p st.global.v2.u32 [%0], {%1, %2};\n"
|
||||
"}\n"
|
||||
:
|
||||
: "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType>
|
||||
struct global_store<AccessType, 4> {
|
||||
CUTLASS_DEVICE
|
||||
global_store(AccessType const &D, void *ptr, bool pred_guard) {
|
||||
uint32_t const &data = reinterpret_cast<uint32_t const &>(D);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %2, 0;\n"
|
||||
" @p st.global.u32 [%0], %1;\n"
|
||||
"}\n"
|
||||
:
|
||||
: "l"(ptr), "r"(data), "r"((int)pred_guard));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType>
|
||||
struct global_store<AccessType, 2> {
|
||||
CUTLASS_DEVICE
|
||||
global_store(AccessType const &D, void *ptr, bool pred_guard) {
|
||||
uint16_t const &data = reinterpret_cast<uint16_t const &>(D);
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %2, 0;\n"
|
||||
" @p st.global.u16 [%0], %1;\n"
|
||||
"}\n"
|
||||
:
|
||||
: "l"(ptr), "h"(data), "r"((int)pred_guard));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccessType>
|
||||
struct global_store<AccessType, 1> {
|
||||
CUTLASS_DEVICE
|
||||
global_store(AccessType const &D, void *ptr, bool pred_guard) {
|
||||
if (pred_guard) *(reinterpret_cast<AccessType *>(ptr)) = D;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
|
@ -42,4 +300,6 @@ namespace arch {
|
|||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "memory_sm75.h"
|
||||
#include "memory_sm80.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -50,20 +50,20 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
|
|||
//
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if ! defined(CUDA_LDMATRIX_SUPPORTED)
|
||||
#define CUDA_LDMATRIX_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2))
|
||||
#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || (__CUDACC_VER_MAJOR__ >= 11)
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
|
||||
#define CUDA_LDMATRIX_ACTIVATED 1
|
||||
#endif
|
||||
|
||||
#if ! defined(CUDA_LDMATRIX_ENABLED)
|
||||
#define CUDA_LDMATRIX_ENABLED CUDA_LDMATRIX_SUPPORTED
|
||||
#endif
|
||||
|
||||
#if CUDA_LDMATRIX_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
|
||||
#define CUDA_LDMATRIX_ACTIVATED 1
|
||||
#define CUDA_LDMATRIX_SUPPORTED 1
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) && (__CUDACC_VER_MAJOR__ > 10)
|
||||
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED 1
|
||||
#endif
|
||||
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED)
|
||||
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1))
|
||||
#endif
|
||||
|
@ -71,8 +71,9 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
|
|||
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED)
|
||||
#define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED
|
||||
#endif
|
||||
*/
|
||||
|
||||
#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED
|
||||
#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
|
||||
extern "C" {
|
||||
//
|
||||
// This NVVM intrinsic is subject to change in future versions of CUDA.
|
||||
|
@ -85,19 +86,49 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
|
|||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED
|
||||
/// CUTLASS helper to get SMEM pointer
|
||||
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
|
||||
|
||||
// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to
|
||||
// the previous internal intrinsics if they are available.
|
||||
#if (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
|
||||
//
|
||||
// This NVVM intrinsic converts an address in shared memory to a plain
|
||||
// unsigned integer. This is necessary to pass to shared memory instructions
|
||||
// in inline PTX.
|
||||
//
|
||||
// In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2].
|
||||
//
|
||||
//__device__ size_t __cvta_generic_to_shared(void* ptr);
|
||||
|
||||
/// CUTLASS helper to get SMEM pointer
|
||||
inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) {
|
||||
return __nvvm_get_smem_pointer(const_cast<void *>(ptr));
|
||||
}
|
||||
return static_cast<unsigned>(__cvta_generic_to_shared(ptr));
|
||||
|
||||
/// CUTLASS helper to get SMEM pointer
|
||||
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
|
||||
return __nvvm_get_smem_pointer(ptr);
|
||||
}
|
||||
#elif (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
|
||||
|
||||
return __nvvm_get_smem_pointer(ptr);
|
||||
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
|
||||
uint32_t smem_ptr;
|
||||
|
||||
asm(
|
||||
"{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
|
||||
: "=r"(smem_ptr) : "l"(ptr));
|
||||
|
||||
return smem_ptr;
|
||||
|
||||
#else
|
||||
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
/// CUTLASS helper to get SMEM pointer
|
||||
inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) {
|
||||
return cutlass_get_smem_pointer(const_cast<void *>(ptr));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
|
@ -235,5 +266,6 @@ inline __device__ void ldsm<layout::ColumnMajor, 4>(
|
|||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Architecture-specific operators on memory added for SM80
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/arch/cache_operation.h"
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
#define CUDA_CP_ASYNC_ACTIVATED 1
|
||||
#else
|
||||
#define CUDA_CP_ASYNC_ACTIVATED 0
|
||||
#endif
|
||||
|
||||
namespace cutlass {
|
||||
namespace arch {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Initiates an asynchronous copy from global memory to shared memory.
|
||||
///
|
||||
/// LDGSTS
|
||||
///
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes,
|
||||
/// Cache operation
|
||||
CacheOperation::Kind cache_op = CacheOperation::Always>
|
||||
struct cp_async;
|
||||
|
||||
/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate
|
||||
/// the entire transfer, zeros are written to SMEM if the guard predicate is false.
|
||||
///
|
||||
/// LDGSTS
|
||||
///
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes,
|
||||
/// Cache operation
|
||||
CacheOperation::Kind cache_op = CacheOperation::Always>
|
||||
struct cp_async_zfill;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct cp_async<SizeInBytes, CacheOperation::Always> {
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
#if CUDA_CP_ASYNC_ACTIVATED
|
||||
|
||||
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred_guard),
|
||||
"r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
|
||||
|
||||
#else
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct cp_async_zfill<SizeInBytes, CacheOperation::Always> {
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
|
||||
#if CUDA_CP_ASYNC_ACTIVATED
|
||||
|
||||
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
|
||||
int src_in_bytes = (pred_guard ? SizeInBytes : 0);
|
||||
|
||||
asm volatile(
|
||||
"cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
|
||||
"l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes));
|
||||
|
||||
#else
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
else {
|
||||
AccessType zeros;
|
||||
zeros.clear();
|
||||
*static_cast<AccessType *>(smem_ptr) = zeros;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct cp_async<SizeInBytes, CacheOperation::Global> {
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
#if CUDA_CP_ASYNC_ACTIVATED
|
||||
|
||||
static_assert(SizeInBytes == 16,
|
||||
"cp.async only supports CacheOperation::Global when access size is 16B.");
|
||||
|
||||
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred_guard),
|
||||
"r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
|
||||
|
||||
#else
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct cp_async_zfill<SizeInBytes, CacheOperation::Global> {
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
#if CUDA_CP_ASYNC_ACTIVATED
|
||||
|
||||
static_assert(SizeInBytes == 16,
|
||||
"cp.async only supports CacheOperation::Global when access size is 16B.");
|
||||
|
||||
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
|
||||
int src_in_bytes = (pred_guard ? SizeInBytes : 0);
|
||||
|
||||
asm volatile(
|
||||
"cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
|
||||
"l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes));
|
||||
|
||||
#else
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
else {
|
||||
AccessType zeros;
|
||||
zeros.clear();
|
||||
*static_cast<AccessType *>(smem_ptr) = zeros;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
|
||||
CUTLASS_DEVICE
|
||||
void cp_async_fence() {
|
||||
#if CUDA_CP_ASYNC_ACTIVATED
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Blocks until all but <N> previous cp.async.commit_group operations have committed.
|
||||
template <int N>
|
||||
CUTLASS_DEVICE void cp_async_wait() {
|
||||
#if CUDA_CP_ASYNC_ACTIVATED
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Blocks until all previous cp.async.commit_group operations have committed.
|
||||
template <>
|
||||
CUTLASS_DEVICE void cp_async_wait<0>() {
|
||||
#if CUDA_CP_ASYNC_ACTIVATED
|
||||
asm volatile("cp.async.wait_all;\n" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -51,11 +51,26 @@ struct OpMultiplyAddSaturate;
|
|||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tag indicating the input is converted to a narrower type (BF16)
|
||||
struct OpMultiplyAddFastBF16;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tag indicating the input is converted to a narrower type (F16)
|
||||
struct OpMultiplyAddFastF16;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tag indicating the complex multiply-add operation
|
||||
struct OpMultiplyAddComplex;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tag indicating the gaussian complex multiply-add operation
|
||||
struct OpMultiplyAddGaussianComplex;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tag indicating the inner product is defined by (XOR, POPC)
|
||||
struct OpXorPopc;
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -120,8 +120,7 @@ struct Wmma<
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// WMMA template structure defines nvcuda::wmma::fragments and static assert for
|
||||
// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1)
|
||||
// (nvcuda::wmma targetting SASS instruction BMMA)
|
||||
// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1).
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
template <
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -167,7 +167,7 @@ public:
|
|||
class const_iterator {
|
||||
|
||||
/// Pointer to object
|
||||
T *ptr_;
|
||||
const T *ptr_;
|
||||
|
||||
public:
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -0,0 +1,461 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Defines a proxy class for storing non-standard 16-bit floating point values with
|
||||
8 bits of exponent and 7 bit of mantissa.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <cstdint>
|
||||
#endif
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Floating-point type with 8 bits of exponent and 7 bits of mantissa.
|
||||
struct alignas(2) bfloat16_t {
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Storage type
|
||||
uint16_t storage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs from an unsigned short
|
||||
CUTLASS_HOST_DEVICE
|
||||
static bfloat16_t bitcast(uint16_t x) {
|
||||
bfloat16_t h;
|
||||
h.storage = x;
|
||||
return h;
|
||||
}
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t() { }
|
||||
|
||||
/// Floating-point conversion - round toward nearest
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit bfloat16_t(float x) {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
|
||||
|
||||
asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x));
|
||||
|
||||
#else
|
||||
uint32_t bits = reinterpret_cast<uint32_t &>(x);
|
||||
|
||||
if ((bits & 0x7f800000) != 0x7f800000) {
|
||||
|
||||
bool mantissa_bit = ((bits & (1 << 16)) != 0);
|
||||
bool round_bit = ((bits & (1 << 15)) != 0);
|
||||
bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0);
|
||||
|
||||
if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) {
|
||||
bits += uint32_t(1 << 16);
|
||||
}
|
||||
}
|
||||
else if (bits & ~0xff800000) {
|
||||
bits = 0x7fffffff;
|
||||
}
|
||||
|
||||
storage = uint16_t((bits >> 16) & 0xffff);
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Floating-point conversion - round toward nearest
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit bfloat16_t(double x): bfloat16_t(float(x)) {
|
||||
|
||||
}
|
||||
|
||||
/// Integer conversion - round toward nearest
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit bfloat16_t(int x) {
|
||||
float flt = static_cast<float>(x);
|
||||
storage = uint16_t(reinterpret_cast<uint32_t const &>(flt) >> 16);
|
||||
}
|
||||
|
||||
/// Converts to float
|
||||
CUTLASS_HOST_DEVICE
|
||||
operator float() const {
|
||||
unsigned bits = (unsigned(storage) << 16);
|
||||
return reinterpret_cast<float const &>(bits);
|
||||
}
|
||||
|
||||
/// Converts to float
|
||||
CUTLASS_HOST_DEVICE
|
||||
operator double() const {
|
||||
return double(float(*this));
|
||||
}
|
||||
|
||||
/// Converts to int
|
||||
CUTLASS_HOST_DEVICE
|
||||
explicit operator int() const {
|
||||
return int(float(*this));
|
||||
}
|
||||
|
||||
/// Casts to bool
|
||||
CUTLASS_HOST_DEVICE
|
||||
operator bool() const {
|
||||
return (float(*this) != 0.0f);
|
||||
}
|
||||
|
||||
/// Obtains raw bits
|
||||
CUTLASS_HOST_DEVICE
|
||||
uint16_t raw() const {
|
||||
return storage;
|
||||
}
|
||||
/// Returns the sign bit
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool signbit() const {
|
||||
return ((raw() & 0x8000) != 0);
|
||||
}
|
||||
|
||||
/// Returns the biased exponent
|
||||
CUTLASS_HOST_DEVICE
|
||||
int exponent_biased() const {
|
||||
return int((raw() >> 7) & 0x0ff);
|
||||
}
|
||||
|
||||
/// Returns the unbiased exponent
|
||||
CUTLASS_HOST_DEVICE
|
||||
int exponent() const {
|
||||
return exponent_biased() - 127;
|
||||
}
|
||||
|
||||
/// Returns the mantissa
|
||||
CUTLASS_HOST_DEVICE
|
||||
int mantissa() const {
|
||||
return int(raw() & 0x7f);
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool signbit(cutlass::bfloat16_t const& h) {
|
||||
return h.signbit();
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) {
|
||||
return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fffffff);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool isnan(cutlass::bfloat16_t const& h) {
|
||||
return (h.exponent_biased() == 0x0ff) && h.mantissa();
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool isfinite(cutlass::bfloat16_t const& h) {
|
||||
return (h.exponent_biased() != 0x0ff);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::bfloat16_t nan_bf16(const char*) {
|
||||
// NVIDIA canonical NaN
|
||||
return cutlass::bfloat16_t::bitcast(0x7fff);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool isinf(cutlass::bfloat16_t const& h) {
|
||||
return (h.exponent_biased() == 0x0ff) && !h.mantissa();
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool isnormal(cutlass::bfloat16_t const& h) {
|
||||
return h.exponent_biased() && h.exponent_biased() != 0x0ff;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
int fpclassify(cutlass::bfloat16_t const& h) {
|
||||
int exp = h.exponent_biased();
|
||||
int mantissa = h.mantissa();
|
||||
if (exp == 0x0ff) {
|
||||
if (mantissa) {
|
||||
return FP_NAN;
|
||||
}
|
||||
else {
|
||||
return FP_INFINITE;
|
||||
}
|
||||
}
|
||||
else if (!exp) {
|
||||
if (mantissa) {
|
||||
return FP_SUBNORMAL;
|
||||
}
|
||||
else {
|
||||
return FP_ZERO;
|
||||
}
|
||||
}
|
||||
return FP_NORMAL;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) {
|
||||
#if defined(__CUDACC_RTC__)
|
||||
return cutlass::bfloat16_t(sqrtf(float(h)));
|
||||
#else
|
||||
return cutlass::bfloat16_t(std::sqrt(float(h)));
|
||||
#endif
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) {
|
||||
|
||||
uint16_t a_mag = (reinterpret_cast<uint16_t const &>(a) & 0x7fff);
|
||||
uint16_t b_sign = (reinterpret_cast<uint16_t const &>(b) & 0x8000);
|
||||
uint16_t result = (a_mag | b_sign);
|
||||
|
||||
return reinterpret_cast<bfloat16_t const &>(result);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Standard Library operations and definitions
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace std {
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/// Numeric limits
|
||||
template <>
|
||||
struct numeric_limits<cutlass::bfloat16_t> {
|
||||
static bool const is_specialized = true;
|
||||
static bool const is_signed = true;
|
||||
static bool const is_integer = false;
|
||||
static bool const is_exact = false;
|
||||
static bool const has_infinity = true;
|
||||
static bool const has_quiet_NaN = true;
|
||||
static bool const has_signaling_NaN = false;
|
||||
static std::float_denorm_style const has_denorm = std::denorm_present;
|
||||
static bool const has_denorm_loss = true;
|
||||
static std::float_round_style const round_style = std::round_to_nearest;
|
||||
static bool const is_iec559 = false;
|
||||
static bool const is_bounded = true;
|
||||
static bool const is_modulo = false;
|
||||
static int const digits = 7;
|
||||
|
||||
/// Least positive value
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); }
|
||||
|
||||
/// Minimum finite value
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); }
|
||||
|
||||
/// Maximum finite value
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); }
|
||||
|
||||
/// Returns smallest finite value
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); }
|
||||
|
||||
/// Returns smallest finite value
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); }
|
||||
|
||||
/// Returns smallest finite value
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); }
|
||||
|
||||
/// Returns smallest finite value
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }
|
||||
|
||||
/// Returns smallest finite value
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }
|
||||
|
||||
/// Returns smallest finite value
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); }
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace std
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Arithmetic operators
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) {
|
||||
return float(lhs) == float(rhs);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
|
||||
return float(lhs) != float(rhs);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) {
|
||||
return float(lhs) < float(rhs);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
|
||||
return float(lhs) <= float(rhs);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) {
|
||||
return float(lhs) > float(rhs);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
|
||||
return float(lhs) >= float(rhs);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) {
|
||||
return bfloat16_t(float(lhs) + float(rhs));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t operator-(bfloat16_t const& lhs) {
|
||||
return bfloat16_t(-float(lhs));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) {
|
||||
return bfloat16_t(float(lhs) - float(rhs));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) {
|
||||
return bfloat16_t(float(lhs) * float(rhs));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) {
|
||||
return bfloat16_t(float(lhs) / float(rhs));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) {
|
||||
lhs = bfloat16_t(float(lhs) + float(rhs));
|
||||
return lhs;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) {
|
||||
lhs = bfloat16_t(float(lhs) - float(rhs));
|
||||
return lhs;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) {
|
||||
lhs = bfloat16_t(float(lhs) * float(rhs));
|
||||
return lhs;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) {
|
||||
lhs = bfloat16_t(float(lhs) / float(rhs));
|
||||
return lhs;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t& operator++(bfloat16_t & lhs) {
|
||||
float tmp(lhs);
|
||||
++tmp;
|
||||
lhs = bfloat16_t(tmp);
|
||||
return lhs;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t& operator--(bfloat16_t & lhs) {
|
||||
float tmp(lhs);
|
||||
--tmp;
|
||||
lhs = bfloat16_t(tmp);
|
||||
return lhs;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t operator++(bfloat16_t & lhs, int) {
|
||||
bfloat16_t ret(lhs);
|
||||
float tmp(lhs);
|
||||
tmp++;
|
||||
lhs = bfloat16_t(tmp);
|
||||
return ret;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
bfloat16_t operator--(bfloat16_t & lhs, int) {
|
||||
bfloat16_t ret(lhs);
|
||||
float tmp(lhs);
|
||||
tmp--;
|
||||
lhs = bfloat16_t(tmp);
|
||||
return ret;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// User-defined literals
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::bfloat16_t operator "" _bf16(long double x) {
|
||||
return cutlass::bfloat16_t(float(x));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) {
|
||||
return cutlass::bfloat16_t(int(x));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -35,6 +35,9 @@
|
|||
#include "cutlass/half.h"
|
||||
#include "cutlass/real.h"
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/tfloat32.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <iosfwd>
|
||||
#endif
|
||||
|
@ -370,6 +373,15 @@ template <typename T>
|
|||
CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const &z) {
|
||||
return complex<T>(real(z), -imag(z));
|
||||
}
|
||||
/// Indentity transform for non-complex types
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T conj(T const &z) {
|
||||
static_assert( !std::is_same<T, cuComplex>::value &&
|
||||
!std::is_same<T, cuDoubleComplex>::value &&
|
||||
!std::is_same<T, cutlass::complex<double>>::value &&
|
||||
!std::is_same<T, cutlass::complex<float>>::value, "May not be a complex data type");
|
||||
return z;
|
||||
}
|
||||
|
||||
/// Projects the complex number z onto the Riemann sphere
|
||||
template <typename T>
|
||||
|
@ -429,6 +441,7 @@ template <typename T>
|
|||
struct RealType< complex<T> > {
|
||||
using Type = T;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static complex<T> from_real(double x) {
|
||||
return complex<T>(static_cast<T>(x));
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -360,6 +360,29 @@ public:
|
|||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
/// Scalar multiplication
|
||||
template <int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank, Index> operator*(Index s, Coord<Rank, Index> coord) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
coord[i] *= s;
|
||||
}
|
||||
return coord;
|
||||
}
|
||||
|
||||
/// Scalar multiplication
|
||||
template <int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank, Index> operator*(Coord<Rank, Index> coord, Index s) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
coord[i] *= s;
|
||||
}
|
||||
return coord;
|
||||
}
|
||||
|
||||
/// Scalar division
|
||||
template <int Rank, typename Index>
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
@ -419,3 +442,4 @@ Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
|
|||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -33,9 +33,14 @@
|
|||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// stream operators for cutlass namespace //
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Rank>
|
||||
|
@ -47,8 +52,6 @@ std::ostream& operator<<(std::ostream& out, Coord<Rank> const& coord) {
|
|||
return out;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline
|
||||
std::istream & operator>>(std::istream &stream, half_t &x) {
|
||||
float tmp;
|
||||
|
@ -62,6 +65,16 @@ std::ostream & operator<<(std::ostream &out, half_t const &x) {
|
|||
return out << float(x);
|
||||
}
|
||||
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, bfloat16_t const &x) {
|
||||
return out << float(x);
|
||||
}
|
||||
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, tfloat32_t const &x) {
|
||||
return out << float(x);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to enable formatted printing of CUTLASS scalar types to an ostream
|
||||
|
@ -98,7 +111,54 @@ inline std::ostream &operator<<(std::ostream &out, ScalarIO<uint8_t> const &scal
|
|||
return out << unsigned(scalar.value);
|
||||
}
|
||||
|
||||
|
||||
/// Default printing to ostream for MatrixShape
|
||||
template <int Row, int Column>
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, cutlass::MatrixShape<Row, Column> const &matrix_shape) {
|
||||
out << "cutlass::MatrixShape::(kRow, kColumn) {"
|
||||
<< cutlass::MatrixShape<Row,Column>::kRow <<","
|
||||
<< cutlass::MatrixShape<Row,Column>::kColumn <<"}";
|
||||
return out;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// stream operators for cutlass::gemm namespace //
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
namespace gemm {
|
||||
|
||||
/// Default printing to ostream for GemmShape
|
||||
template <int M, int N, int K>
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, cutlass::gemm::GemmShape<M,N,K> const &gemm_shape) {
|
||||
out << "cutlass::GemmShape::(kM, kN, kK) {"
|
||||
<< cutlass::gemm::GemmShape<M,N,K>::kM <<","
|
||||
<< cutlass::gemm::GemmShape<M,N,K>::kN <<","
|
||||
<< cutlass::gemm::GemmShape<M,N,K>::kK << "}";
|
||||
return out;
|
||||
}
|
||||
|
||||
} //namespace gemm
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// stream operators for cutlass::layout namespace //
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
namespace layout {
|
||||
|
||||
/// Default printing to ostream for PitchLinearShape
|
||||
template < int Contiguous, int Strided>
|
||||
inline
|
||||
std::ostream & operator<<(std::ostream &out, cutlass::layout::PitchLinearShape<Contiguous, Strided> const &pitch_linear_shape) {
|
||||
out << "cutlass::layout::PitchLinearShape::(kContiguous, kStrided) {"
|
||||
<< cutlass::layout::PitchLinearShape<Contiguous,Strided>::kContiguous <<","
|
||||
<< cutlass::layout::PitchLinearShape<Contiguous,Strided>::kStrided <<"}";
|
||||
return out;
|
||||
}
|
||||
|
||||
} //namespace layout
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -55,6 +55,8 @@ enum class Status {
|
|||
kErrorNotSupported, ///< Operation is not supported on current device.
|
||||
kErrorWorkspaceNull, ///< The given workspace is null when it is required to be non-null.
|
||||
kErrorInternal, ///< An error within CUTLASS occurred.
|
||||
kErrorArchMismatch, ///< CUTLASS runs on a device that it was not compiled for.
|
||||
kErrorInsufficientDriver, ///< CUTLASS runs with a driver that is too old.
|
||||
kInvalid ///< Status is unspecified.
|
||||
};
|
||||
|
||||
|
@ -78,6 +80,10 @@ static char const* cutlassGetStatusString(cutlass::Status status) {
|
|||
return "Error Workspace Null";
|
||||
case cutlass::Status::kErrorInternal:
|
||||
return "Error Internal";
|
||||
case cutlass::Status::kErrorInsufficientDriver:
|
||||
return "Error Insufficient Driver";
|
||||
case cutlass::Status::kErrorArchMismatch:
|
||||
return "Erroor Architecture Mismatch";
|
||||
case cutlass::Status::kInvalid: break;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief This extends the contents of cutlass/functional.h with frequently used activation functions.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/functional.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace thread {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// ReLu operator - propagates NaNs
|
||||
template <typename T>
|
||||
struct ReLu {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const & threshold, T const &value) const {
|
||||
if (value < threshold) {
|
||||
value = threshold;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct ReLu<Array<T, N>> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
|
||||
Array<T, N> result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
T value = frag[i];
|
||||
if (value < threshold) {
|
||||
value = threshold;
|
||||
}
|
||||
result[i] = value;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// Sigmoid operator
|
||||
template <typename T>
|
||||
struct Sigmoid {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &scalar) const {
|
||||
return T(1) / (T(1) + exp(-scalar));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Sigmoid<float> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
float operator()(float const &scalar) const {
|
||||
return 1.0f / (1.0f + expf(-scalar));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct Sigmoid<Array<T, N> > {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Array<T, N> operator()(Array<T, N> const &rhs) const {
|
||||
Array<T, N> y;
|
||||
Sigmoid<T> sigmoid_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < int(rhs.size()); ++i) {
|
||||
y[i] = sigmoid_op(rhs[i]);
|
||||
}
|
||||
|
||||
return y;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -101,7 +101,7 @@ public:
|
|||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator,
|
||||
FragmentOutput const &source,
|
||||
FragmentOutput const &source = FragmentOutput(),
|
||||
ElementCompute uniform = ElementCompute(0)) const {
|
||||
|
||||
// Convert to destination numeric type
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -165,6 +165,28 @@ public:
|
|||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
ComputeFragment intermediate;
|
||||
multiplies<ComputeFragment> mul_accumulator;
|
||||
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -178,6 +178,40 @@ public:
|
|||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_accumulator;
|
||||
|
||||
minimum<ComputeFragment> min_accumulator;
|
||||
maximum<ComputeFragment> max_accumulator;
|
||||
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
|
||||
/// Clamping constant value
|
||||
ElementCompute const kClamp =
|
||||
ElementCompute((1U << (sizeof_bits<ElementOutput>::value - 1)) - 1);
|
||||
|
||||
intermediate = max_accumulator(intermediate, -kClamp - ElementCompute(1));
|
||||
intermediate = min_accumulator(intermediate, kClamp);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -278,7 +312,7 @@ public:
|
|||
beta_ = ElementCompute(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
|
@ -316,6 +350,37 @@ public:
|
|||
|
||||
return destination_converter(scaled_accumulator);
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Compute linear scaling in floating point
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_add_accumulator;
|
||||
|
||||
// Float min-max
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
scaled_accumulator[i] = static_cast<int>(intermediate[i]);
|
||||
}
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(scaled_accumulator);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // Conditional guards to enable partial specialization for packed integers
|
||||
|
@ -410,7 +475,7 @@ class FastLinearCombinationClamp {
|
|||
beta_ = ElementCompute(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(FragmentAccumulator const &accumulator,
|
||||
|
@ -453,6 +518,41 @@ class FastLinearCombinationClamp {
|
|||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
FastNumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
|
||||
accumulator_converter;
|
||||
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Compute linear scaling in floating point
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_accumulator;
|
||||
|
||||
minimum<ComputeFragment> min_accumulator;
|
||||
maximum<ComputeFragment> max_accumulator;
|
||||
|
||||
// Float min-max
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator);
|
||||
|
||||
/// Clamping constant value
|
||||
ElementCompute const kClamp =
|
||||
ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));
|
||||
|
||||
intermediate = max_accumulator(intermediate, -kClamp);
|
||||
intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));
|
||||
|
||||
// Convert to destination numeric type
|
||||
FastNumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
|
||||
destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -185,6 +185,39 @@ public:
|
|||
destination_converter(intermediate.real),
|
||||
destination_converter(intermediate.imag));
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_accumulator(
|
||||
accumulator_converter(accumulator.real),
|
||||
accumulator_converter(accumulator.imag));
|
||||
|
||||
// Perform binary operations
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<Array<ElementCompute, kCount> > mul_op;
|
||||
multiply_add<Array<ElementCompute, kCount> > mul_add_op;
|
||||
|
||||
// complex multiply-add: I = alpha * AB + I
|
||||
intermediate.real = mul_add_op(alpha_.real(), converted_accumulator.real);
|
||||
intermediate.imag = mul_add_op(alpha_.real(), converted_accumulator.imag);
|
||||
|
||||
intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real);
|
||||
intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return FragmentOutput(
|
||||
destination_converter(intermediate.real),
|
||||
destination_converter(intermediate.imag));
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -23,8 +23,7 @@
|
|||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor performing linear combination operations used by epilogues. Values are clamped before
|
||||
converting to the output element type.
|
||||
\brief Functor performing linear combination with a maximum operation used by epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
@ -34,6 +33,7 @@
|
|||
#include "cutlass/array.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
@ -43,8 +43,7 @@ namespace thread {
|
|||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Applies a linear combination operator to an array of elements then clamps the output before
|
||||
/// converting to the output element type.
|
||||
/// Applies a linear combination operator to an array of elements.
|
||||
///
|
||||
/// D = alpha * accumulator + beta * source + uniform
|
||||
///
|
||||
|
@ -75,10 +74,10 @@ public:
|
|||
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
ElementCompute threshold; ///< Relu threshold
|
||||
ElementCompute threshold; ///< minimum value that is output
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
|
||||
ElementCompute const *threshold_ptr; ///< pointer to threshold scalar - if not null, loads from memory
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
@ -87,16 +86,17 @@ public:
|
|||
Params():
|
||||
alpha(ElementCompute(1)),
|
||||
beta(ElementCompute(0)),
|
||||
threshold(ElementCompute(0)),
|
||||
threshold(ElementCompute(0)),
|
||||
alpha_ptr(nullptr),
|
||||
beta_ptr(nullptr) { }
|
||||
beta_ptr(nullptr),
|
||||
threshold_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta,
|
||||
ElementCompute threshold = ElementCompute(0)
|
||||
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
||||
ElementCompute threshold = ElementCompute(0)
|
||||
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr), threshold_ptr(nullptr) {
|
||||
|
||||
}
|
||||
|
||||
|
@ -104,8 +104,8 @@ public:
|
|||
Params(
|
||||
ElementCompute const *alpha_ptr,
|
||||
ElementCompute const *beta_ptr,
|
||||
ElementCompute threshold = ElementCompute(0)
|
||||
): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
||||
ElementCompute const *threshold_ptr = nullptr
|
||||
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), threshold_ptr(threshold_ptr) {
|
||||
|
||||
}
|
||||
};
|
||||
|
@ -128,7 +128,7 @@ public:
|
|||
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
threshold_ = params.threshold;
|
||||
threshold_ = (params.threshold_ptr ? *params.threshold_ptr : params.threshold);
|
||||
}
|
||||
|
||||
/// Returns true if source is needed
|
||||
|
@ -144,13 +144,12 @@ public:
|
|||
beta_ = ElementCompute(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator,
|
||||
FragmentOutput const &source,
|
||||
ElementCompute uniform = ElementCompute(0)) const {
|
||||
FragmentOutput const &source) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
||||
|
@ -160,18 +159,44 @@ public:
|
|||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_add_source;
|
||||
multiply_add<ComputeFragment> mul_add_accumulator;
|
||||
|
||||
maximum<ComputeFragment> max_accumulator;
|
||||
ReLu<ComputeFragment> relu;
|
||||
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
|
||||
intermediate = max_accumulator(intermediate, threshold_);
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(threshold_, intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_accumulator;
|
||||
ReLu<ComputeFragment> relu;
|
||||
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
|
||||
// Compute threshold optionally
|
||||
intermediate = relu(threshold_, intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
@ -180,24 +205,24 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Conditional guards to enable partial specialization for packed integers
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
|
||||
((__CUDACC_VER_MAJOR__ > 10) || \
|
||||
((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
|
||||
|
||||
/// Applies a linear combination operator to an array of elements then clamps the output before
|
||||
/// converting to the output element type.
|
||||
/// Applies a linear combination operator to an array of elements.
|
||||
///
|
||||
/// D = alpha * accumulator + beta * source + uniform
|
||||
///
|
||||
/// Special handling for int types
|
||||
|
||||
template <
|
||||
typename ElementOutput_, ///< Data type used to load and store tensors
|
||||
int Count, ///< Number of elements computed per operation
|
||||
FloatRoundStyle Round
|
||||
>
|
||||
class LinearCombinationRelu<ElementOutput_, Count, int, float, Round> {
|
||||
class LinearCombinationRelu <ElementOutput_, Count, int, float, Round> {
|
||||
public:
|
||||
|
||||
using ElementOutput = ElementOutput_;
|
||||
|
@ -217,10 +242,10 @@ public:
|
|||
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
ElementCompute threshold; ///< Relu threshold
|
||||
ElementCompute threshold; ///< minimum value that is output
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
|
||||
ElementCompute const *threshold_ptr; ///< pointer to threshold scalar - if not null, loads from memory
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
@ -229,16 +254,17 @@ public:
|
|||
Params():
|
||||
alpha(ElementCompute(1)),
|
||||
beta(ElementCompute(0)),
|
||||
threshold(ElementCompute(0)),
|
||||
threshold(ElementCompute(0)),
|
||||
alpha_ptr(nullptr),
|
||||
beta_ptr(nullptr) { }
|
||||
beta_ptr(nullptr),
|
||||
threshold_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta,
|
||||
ElementCompute threshold = ElementCompute(0)
|
||||
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
||||
ElementCompute threshold = ElementCompute(0)
|
||||
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr), threshold_ptr(nullptr) {
|
||||
|
||||
}
|
||||
|
||||
|
@ -246,8 +272,8 @@ public:
|
|||
Params(
|
||||
ElementCompute const *alpha_ptr,
|
||||
ElementCompute const *beta_ptr,
|
||||
ElementCompute threshold = ElementCompute(0)
|
||||
): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
||||
ElementCompute const *threshold_ptr = nullptr
|
||||
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), threshold_ptr(threshold_ptr) {
|
||||
|
||||
}
|
||||
};
|
||||
|
@ -270,7 +296,7 @@ public:
|
|||
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
threshold_ = params.threshold;
|
||||
threshold_ = (params.threshold_ptr ? *params.threshold_ptr : params.threshold);
|
||||
}
|
||||
|
||||
/// Returns true if source is needed
|
||||
|
@ -286,13 +312,12 @@ public:
|
|||
beta_ = ElementCompute(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator,
|
||||
FragmentOutput const &source,
|
||||
ElementCompute uniform = ElementCompute(0)) const {
|
||||
FragmentOutput const &source) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
||||
|
@ -302,21 +327,16 @@ public:
|
|||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_add_source;
|
||||
multiply_add<ComputeFragment> mul_add_accumulator;
|
||||
|
||||
maximum<ComputeFragment> max_accumulator;
|
||||
ReLu<FragmentAccumulator> relu;
|
||||
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
|
||||
// Clamp to theshold
|
||||
intermediate = max_accumulator(intermediate, threshold_);
|
||||
|
||||
// Convert back to accumulator data type
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
|
@ -324,8 +344,46 @@ public:
|
|||
scaled_accumulator[i] = static_cast<int>(intermediate[i]);
|
||||
}
|
||||
|
||||
// Convert to destination numeric type and pack
|
||||
NumericArrayConverter<ElementOutput, ElementAccumulator, kCount, Round> destination_converter;
|
||||
// Compute threshold optionally
|
||||
scaled_accumulator = relu(threshold_, scaled_accumulator);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(scaled_accumulator);
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_accumulator;
|
||||
ReLu<FragmentAccumulator> relu;
|
||||
|
||||
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
|
||||
// Convert floats back to INT
|
||||
FragmentAccumulator scaled_accumulator;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
scaled_accumulator[i] = static_cast<int>(intermediate[i]);
|
||||
}
|
||||
|
||||
// Compute threshold optionally
|
||||
scaled_accumulator = relu(threshold_, scaled_accumulator);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(scaled_accumulator);
|
||||
}
|
||||
|
@ -338,3 +396,6 @@ public:
|
|||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
|
|
@ -0,0 +1,206 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor performing linear combination operations used by epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace thread {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Applies a linear combination operator to an array of elements.
|
||||
///
|
||||
/// D = alpha * accumulator + beta * source + uniform
|
||||
///
|
||||
template <
|
||||
typename ElementOutput_, ///< Data type used to load and store tensors
|
||||
int Count, ///< Number of elements computed per operation
|
||||
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
|
||||
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
|
||||
>
|
||||
class LinearCombinationSigmoid {
|
||||
public:
|
||||
|
||||
using ElementOutput = ElementOutput_;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementCompute = ElementCompute_;
|
||||
|
||||
static int const kCount = Count;
|
||||
|
||||
using FragmentOutput = Array<ElementOutput, kCount>;
|
||||
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
|
||||
using ComputeFragment = Array<ElementCompute, kCount>;
|
||||
|
||||
static FloatRoundStyle const kRound = Round;
|
||||
|
||||
/// Host-constructable parameters structure
|
||||
struct Params {
|
||||
|
||||
ElementCompute alpha; ///< scales accumulators
|
||||
ElementCompute beta; ///< scales source tensor
|
||||
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
|
||||
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params():
|
||||
alpha(ElementCompute(1)),
|
||||
beta(ElementCompute(0)),
|
||||
alpha_ptr(nullptr),
|
||||
beta_ptr(nullptr) { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute alpha,
|
||||
ElementCompute beta
|
||||
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
|
||||
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
ElementCompute const *alpha_ptr,
|
||||
ElementCompute const *beta_ptr
|
||||
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
ElementCompute alpha_;
|
||||
ElementCompute beta_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs the function object, possibly loading from pointers in host memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
LinearCombinationSigmoid(Params const ¶ms) {
|
||||
|
||||
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
|
||||
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
|
||||
}
|
||||
|
||||
/// Returns true if source is needed
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool is_source_needed() const {
|
||||
return beta_ != ElementCompute(0);
|
||||
}
|
||||
|
||||
/// Functionally required for serial reduction in the epilogue
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set_k_partition(int k_partition) {
|
||||
if (k_partition) {
|
||||
beta_ = ElementCompute(1);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator + beta * source
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator,
|
||||
FragmentOutput const &source) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_source = source_converter(source);
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_add_source;
|
||||
multiply_add<ComputeFragment> mul_add_accumulator;
|
||||
Sigmoid<ComputeFragment> sigmoid;
|
||||
|
||||
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
|
||||
|
||||
intermediate = sigmoid(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
|
||||
/// Computes linear scaling: D = alpha * accumulator
|
||||
CUTLASS_HOST_DEVICE
|
||||
FragmentOutput operator()(
|
||||
FragmentAccumulator const &accumulator) const {
|
||||
|
||||
// Convert source to interal compute numeric type
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
|
||||
|
||||
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
|
||||
|
||||
// Perform binary operations
|
||||
|
||||
ComputeFragment intermediate;
|
||||
|
||||
multiplies<ComputeFragment> mul_add_accumulator;
|
||||
Sigmoid<ComputeFragment> sigmoid;
|
||||
|
||||
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
|
||||
|
||||
intermediate = sigmoid(intermediate);
|
||||
|
||||
// Convert to destination numeric type
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
|
||||
|
||||
return destination_converter(intermediate);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -45,6 +45,7 @@
|
|||
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
||||
|
||||
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
|
@ -76,6 +77,7 @@ template <
|
|||
/// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load()
|
||||
int ElementsPerAccess,
|
||||
/// Multiply-add operator
|
||||
/// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
|
||||
typename Operator_ = arch::OpMultiplyAddComplex>
|
||||
struct DefaultEpilogueComplexTensorOp {
|
||||
|
||||
|
@ -146,6 +148,91 @@ struct DefaultEpilogueComplexTensorOp {
|
|||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Partial specialization and defines sensible defaults for epilogues for complex*complex case
|
||||
// 3 real-valued mma operations (Gaussian Complex)
|
||||
// A = (ar + j ai), B = (br +j bi), D = AB
|
||||
// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi)
|
||||
// D = dr + j di = (P1 - P3) + j (P1 + P2)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <
|
||||
typename Shape_,
|
||||
typename WarpMmaTensorOp_,
|
||||
int PartitionsK,
|
||||
typename OutputOp_,
|
||||
int ElementsPerAccess
|
||||
>
|
||||
struct DefaultEpilogueComplexTensorOp <Shape_, WarpMmaTensorOp_, PartitionsK,
|
||||
OutputOp_, ElementsPerAccess,
|
||||
arch::OpMultiplyAddGaussianComplex> {
|
||||
|
||||
using Shape = Shape_;
|
||||
using WarpMmaTensorOp = WarpMmaTensorOp_;
|
||||
static int const kPartitionsK = PartitionsK;
|
||||
using OutputOp = OutputOp_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
using Operator = arch::OpMultiplyAddGaussianComplex;
|
||||
|
||||
using ElementOutput = typename OutputOp::ElementOutput;
|
||||
using LayoutC = typename WarpMmaTensorOp::LayoutC;
|
||||
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
|
||||
|
||||
//
|
||||
// Thread map
|
||||
//
|
||||
|
||||
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
|
||||
Shape,
|
||||
typename WarpMmaTensorOp::Shape,
|
||||
kPartitionsK,
|
||||
ElementOutput,
|
||||
kElementsPerAccess
|
||||
>::Type;
|
||||
|
||||
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
OutputTileThreadMap,
|
||||
ElementOutput
|
||||
>;
|
||||
|
||||
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp<
|
||||
typename WarpMmaTensorOp::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
||||
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
||||
LayoutC
|
||||
>;
|
||||
|
||||
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
typename WarpMmaTensorOp::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
||||
ElementAccumulator,
|
||||
LayoutC
|
||||
>;
|
||||
|
||||
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
||||
typename OutputTileThreadMap::CompactedThreadMap,
|
||||
ElementAccumulator
|
||||
>;
|
||||
|
||||
/// Hard-coded padding elements added
|
||||
using Padding = cutlass::MatrixShape<0, 0>;
|
||||
|
||||
//
|
||||
// Define the epilogue
|
||||
//
|
||||
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
|
||||
Shape,
|
||||
WarpMmaTensorOp,
|
||||
kPartitionsK,
|
||||
OutputTileIterator,
|
||||
AccumulatorFragmentIterator,
|
||||
WarpTileIterator,
|
||||
SharedLoadIterator,
|
||||
OutputOp,
|
||||
Padding
|
||||
>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -148,6 +148,44 @@ struct DefaultEpiloguePlanarComplex<
|
|||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines sensible defaults for epilogues.
|
||||
template <
|
||||
typename ThreadblockShape_,
|
||||
typename WarpMmaOperator_,
|
||||
int PartitionsK,
|
||||
typename OutputOp_,
|
||||
int ElementsPerAccess
|
||||
>
|
||||
struct DefaultEpiloguePlanarComplex<
|
||||
ThreadblockShape_,
|
||||
WarpMmaOperator_,
|
||||
arch::OpClassTensorOp,
|
||||
arch::Sm80,
|
||||
PartitionsK,
|
||||
OutputOp_,
|
||||
ElementsPerAccess> {
|
||||
|
||||
using RealEpilogue = DefaultEpilogueTensorOp<
|
||||
ThreadblockShape_,
|
||||
WarpMmaOperator_,
|
||||
PartitionsK,
|
||||
OutputOp_,
|
||||
ElementsPerAccess
|
||||
>;
|
||||
|
||||
using Epilogue = EpiloguePlanarComplex<
|
||||
ThreadblockShape_,
|
||||
WarpMmaOperator_,
|
||||
PartitionsK,
|
||||
typename RealEpilogue::OutputTileIterator,
|
||||
typename RealEpilogue::AccumulatorFragmentIterator,
|
||||
typename RealEpilogue::WarpTileIterator,
|
||||
typename RealEpilogue::SharedLoadIterator,
|
||||
OutputOp_,
|
||||
typename RealEpilogue::Padding
|
||||
>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines sensible defaults for epilogues.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -39,6 +39,7 @@
|
|||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
||||
#include "cutlass/epilogue/thread/conversion_op.h"
|
||||
#include "cutlass/epilogue/thread/reduction_op.h"
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -39,16 +39,20 @@
|
|||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
||||
#include "cutlass/epilogue/thread/conversion_op.h"
|
||||
#include "cutlass/epilogue/thread/reduction_op.h"
|
||||
|
||||
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
||||
|
||||
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
||||
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
|
||||
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
||||
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
||||
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/epilogue.h"
|
||||
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
|
||||
|
@ -61,6 +65,177 @@ namespace threadblock {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <
|
||||
typename ElementOutput,
|
||||
typename ElementAccumulator,
|
||||
int ElementsPerAccess,
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename ThreadMap
|
||||
>
|
||||
struct DefaultIteratorsTensorOp {
|
||||
|
||||
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
ElementAccumulator,
|
||||
layout::RowMajor
|
||||
>;
|
||||
|
||||
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
||||
ThreadMap,
|
||||
ElementAccumulator
|
||||
>;
|
||||
};
|
||||
|
||||
/// Partial specialization for half <= float x 8 epilogues avoids shared memory bank conflicts.
|
||||
template <
|
||||
typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename ThreadMap
|
||||
>
|
||||
struct DefaultIteratorsTensorOp<
|
||||
half_t,
|
||||
float,
|
||||
8,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
ThreadMap> {
|
||||
|
||||
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
float,
|
||||
32,
|
||||
16,
|
||||
8,
|
||||
8
|
||||
>;
|
||||
|
||||
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
||||
ThreadMap,
|
||||
float,
|
||||
32,
|
||||
16,
|
||||
8,
|
||||
8
|
||||
>;
|
||||
};
|
||||
|
||||
/// Partial specialization for int8_t x 16 <= int32_t x 16 epilogues avoids shared memory bank conflicts.
|
||||
template <
|
||||
int K,
|
||||
typename InstructionShape,
|
||||
typename ThreadMap
|
||||
>
|
||||
struct DefaultIteratorsTensorOp<
|
||||
int8_t,
|
||||
int32_t,
|
||||
16,
|
||||
gemm::GemmShape<128, 128, K>,
|
||||
gemm::GemmShape<64, 64, K>,
|
||||
InstructionShape,
|
||||
ThreadMap> {
|
||||
|
||||
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
||||
gemm::GemmShape<64, 64, K>,
|
||||
InstructionShape,
|
||||
int32_t,
|
||||
32,
|
||||
8,
|
||||
16,
|
||||
8
|
||||
>;
|
||||
|
||||
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
||||
ThreadMap,
|
||||
int32_t,
|
||||
32,
|
||||
8,
|
||||
16,
|
||||
8
|
||||
>;
|
||||
};
|
||||
|
||||
/// Partial specialization for int8_t x 8 <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
||||
template <
|
||||
int K,
|
||||
typename InstructionShape,
|
||||
typename ThreadMap
|
||||
>
|
||||
struct DefaultIteratorsTensorOp<
|
||||
int8_t,
|
||||
int32_t,
|
||||
8,
|
||||
gemm::GemmShape<128, 64, K>,
|
||||
gemm::GemmShape<64, 32, K>,
|
||||
InstructionShape,
|
||||
ThreadMap> {
|
||||
|
||||
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
||||
gemm::GemmShape<64, 32, K>,
|
||||
InstructionShape,
|
||||
int32_t,
|
||||
32,
|
||||
8,
|
||||
8,
|
||||
8
|
||||
>;
|
||||
|
||||
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
||||
ThreadMap,
|
||||
int32_t,
|
||||
32,
|
||||
8,
|
||||
8,
|
||||
8
|
||||
>;
|
||||
};
|
||||
|
||||
/// Partial specialization for int8_t x 8 <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
||||
template <
|
||||
int K,
|
||||
typename InstructionShape,
|
||||
typename ThreadMap
|
||||
>
|
||||
struct DefaultIteratorsTensorOp<
|
||||
int8_t,
|
||||
int32_t,
|
||||
8,
|
||||
gemm::GemmShape<64, 64, K>,
|
||||
gemm::GemmShape<32, 32, K>,
|
||||
InstructionShape,
|
||||
ThreadMap> {
|
||||
|
||||
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
|
||||
gemm::GemmShape<32, 32, K>,
|
||||
InstructionShape,
|
||||
int32_t,
|
||||
32,
|
||||
8,
|
||||
8,
|
||||
8
|
||||
>;
|
||||
|
||||
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
|
||||
ThreadMap,
|
||||
int32_t,
|
||||
32,
|
||||
8,
|
||||
8,
|
||||
8
|
||||
>;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines sensible defaults for epilogues for TensorOps.
|
||||
template <
|
||||
typename Shape_,
|
||||
|
@ -98,25 +273,33 @@ struct DefaultEpilogueTensorOp {
|
|||
ElementOutput
|
||||
>;
|
||||
|
||||
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
typename WarpMmaTensorOp::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
||||
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
||||
LayoutC
|
||||
>;
|
||||
using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,
|
||||
cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
|
||||
typename WarpMmaTensorOp::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
||||
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
||||
LayoutC>,
|
||||
cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
typename WarpMmaTensorOp::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::ElementC,
|
||||
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
|
||||
LayoutC> >::type;
|
||||
|
||||
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
|
||||
typename WarpMmaTensorOp::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
||||
/// Support several implementations depending on structure of epilogue
|
||||
using DefaultIterators = detail::DefaultIteratorsTensorOp<
|
||||
ElementOutput,
|
||||
ElementAccumulator,
|
||||
LayoutC
|
||||
kElementsPerAccess,
|
||||
Shape,
|
||||
typename WarpMmaTensorOp::Shape,
|
||||
typename WarpMmaTensorOp::Policy::Operator::Shape,
|
||||
typename OutputTileThreadMap::CompactedThreadMap
|
||||
>;
|
||||
|
||||
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
|
||||
typename OutputTileThreadMap::CompactedThreadMap,
|
||||
ElementAccumulator
|
||||
>;
|
||||
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
|
||||
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
|
||||
|
||||
/// Hard-coded padding elements added
|
||||
using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;
|
||||
|
@ -184,6 +367,7 @@ struct DefaultInterleavedEpilogueTensorOp {
|
|||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -39,6 +39,7 @@
|
|||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
||||
#include "cutlass/epilogue/thread/conversion_op.h"
|
||||
#include "cutlass/epilogue/thread/reduction_op.h"
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -39,6 +39,7 @@
|
|||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
||||
#include "cutlass/epilogue/thread/conversion_op.h"
|
||||
#include "cutlass/epilogue/thread/reduction_op.h"
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -146,54 +146,6 @@ struct DefaultInterleavedThreadMapTensorOp {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines the optimal thread map for TensorOp accumulator layouts
|
||||
template <typename ThreadblockShape_, typename WarpShape_, int PartitionsK,
|
||||
typename Element_, int ElementsPerAccess, int InterleavedK>
|
||||
struct DefaultInterleavedConvThreadMapTensorOp {
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
using WarpShape = WarpShape_;
|
||||
static int const kPartitionsK = PartitionsK;
|
||||
using Element = Element_;
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kInterleavedK = InterleavedK;
|
||||
|
||||
//
|
||||
// Definitions
|
||||
//
|
||||
|
||||
struct Detail {
|
||||
/// Tensor Operations fundamentally perform operations on 8 rows
|
||||
static int const kTensorOpRows = 8;
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
|
||||
!(ThreadblockShape::kM % WarpShape::kM),
|
||||
"Divisibility");
|
||||
|
||||
/// Number of warps
|
||||
using WarpCount =
|
||||
gemm::GemmShape<ThreadblockShape::kM / WarpShape::kM,
|
||||
ThreadblockShape::kN / WarpShape::kN, kPartitionsK>;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = WarpCount::kCount * kWarpSize;
|
||||
};
|
||||
|
||||
//
|
||||
// ThreadMap
|
||||
//
|
||||
|
||||
/// ThreadMap to be used by epilogue::MaskedTileIterator satisfying concept
|
||||
/// InterleavedOutputTileThreadMap
|
||||
using Type = InterleavedConvOutputTileThreadMap<
|
||||
MatrixShape<Detail::WarpCount::kM, Detail::WarpCount::kN>,
|
||||
MatrixShape<WarpShape::kM / Detail::kTensorOpRows,
|
||||
WarpShape::kN / InterleavedK>,
|
||||
Detail::kThreads, kElementsPerAccess, sizeof_bits<Element>::value>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -175,15 +175,106 @@ public:
|
|||
OutputOp const &output_op, ///< Output operator
|
||||
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
||||
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
||||
OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
||||
int64_t imag_stride_dest = 0, ///< Arguments required for planar complex case - not used in real-valued case
|
||||
int64_t imag_stride_src = 0) { ///<
|
||||
|
||||
typename OutputTileIterator::Fragment source_fragment;
|
||||
|
||||
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
||||
|
||||
if (!output_op.is_source_needed()) {
|
||||
source_iterator.clear_mask();
|
||||
compute_source_not_needed_(output_op, destination_iterator, accumulators);
|
||||
}
|
||||
else {
|
||||
compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
/// Streams the result to global memory
|
||||
CUTLASS_DEVICE
|
||||
void compute_source_not_needed_(
|
||||
OutputOp const &output_op, ///< Output operator
|
||||
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
||||
AccumulatorTile const &accumulators) { ///< Complete warp-level accumulator tile
|
||||
|
||||
//
|
||||
// Iterator over warp-level accumulator fragment
|
||||
//
|
||||
|
||||
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
|
||||
|
||||
//
|
||||
// Iterate over accumulator tile
|
||||
//
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
|
||||
|
||||
//
|
||||
// Convert and store fragment
|
||||
//
|
||||
|
||||
__syncthreads();
|
||||
|
||||
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
||||
|
||||
accum_fragment_iterator.load(accum_fragment);
|
||||
++accum_fragment_iterator;
|
||||
|
||||
this->warp_tile_iterator_.store(accum_fragment);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// Load fragments from shared memory
|
||||
//
|
||||
|
||||
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
|
||||
|
||||
shared_load_iterator_.load(aligned_accum_fragment[0]);
|
||||
|
||||
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
|
||||
if (kPartitionsK > 1)
|
||||
{
|
||||
plus <typename SharedLoadIterator::Fragment> add_fragments;
|
||||
const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for ( int i = 1; i < kPartitionsK; ++i) {
|
||||
shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
|
||||
shared_load_iterator_.load(aligned_accum_fragment[i]);
|
||||
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
|
||||
}
|
||||
|
||||
shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
|
||||
}
|
||||
|
||||
//
|
||||
// Compute the output result
|
||||
//
|
||||
|
||||
typename OutputTileIterator::Fragment output_fragment;
|
||||
|
||||
apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment[0]);
|
||||
|
||||
|
||||
//
|
||||
// Store the final result
|
||||
//
|
||||
|
||||
destination_iterator.store(output_fragment);
|
||||
++destination_iterator;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Streams the result to global memory
|
||||
CUTLASS_DEVICE
|
||||
void compute_source_needed_(
|
||||
OutputOp const &output_op, ///< Output operator
|
||||
OutputTileIterator destination_iterator, ///< Tile iterator for destination
|
||||
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
|
||||
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
|
||||
|
||||
typename OutputTileIterator::Fragment source_fragment;
|
||||
|
||||
source_fragment.clear();
|
||||
|
||||
|
@ -265,8 +356,6 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
/// Helper to invoke the output functor over each vector of output
|
||||
CUTLASS_DEVICE
|
||||
void apply_output_operator_(
|
||||
|
@ -294,6 +383,30 @@ private:
|
|||
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to invoke the output functor over each vector of output
|
||||
CUTLASS_DEVICE
|
||||
void apply_output_operator_source_not_needed_(
|
||||
typename OutputTileIterator::Fragment &output_fragment,
|
||||
OutputOp const &output_op, ///< Output operator
|
||||
typename SharedLoadIterator::Fragment const &aligned_accum_fragment) {
|
||||
|
||||
OutputAccessType *output_frag_ptr =
|
||||
reinterpret_cast<OutputAccessType *>(&output_fragment);
|
||||
|
||||
AccumulatorAccessType const *compute_frag_ptr =
|
||||
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
|
||||
|
||||
int const kOutputOpIterations =
|
||||
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kOutputOpIterations; ++i) {
|
||||
|
||||
// Call the output operator
|
||||
output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -490,67 +490,6 @@ struct InterleavedOutputTileThreadMap {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template metaprogram for partitioning a 4D interleaved layout across warps
|
||||
/// to achieve several performance objectives:
|
||||
///
|
||||
/// - coalesced memory accesses in units of 64 Byte lines
|
||||
/// - minimal address arithmetic
|
||||
/// - minimal predicate calculations
|
||||
///
|
||||
template <typename WarpCount_, typename Iterations_, int Threads,
|
||||
int ElementsPerAccess, int ElementSize>
|
||||
struct InterleavedConvOutputTileThreadMap {
|
||||
using WarpCount = WarpCount_;
|
||||
|
||||
static int const kWarpSize = 32;
|
||||
static int const kThreads = Threads;
|
||||
static int const kWarpCount = kThreads / kWarpSize;
|
||||
|
||||
static int const kElementsPerAccess = ElementsPerAccess;
|
||||
static int const kElementSize = ElementSize;
|
||||
|
||||
//
|
||||
// Metaprogram computation
|
||||
//
|
||||
|
||||
struct Detail {};
|
||||
|
||||
//
|
||||
// Output
|
||||
//
|
||||
|
||||
using Iterations = Iterations_;
|
||||
|
||||
using Delta = MatrixShape<kWarpSize / 4, 4 * kElementsPerAccess>;
|
||||
|
||||
/// Initial offset function
|
||||
CUTLASS_HOST_DEVICE
|
||||
static MatrixCoord initial_offset(int thread_idx) {
|
||||
int warp_idx = thread_idx / kWarpSize;
|
||||
int lane_idx = thread_idx % kWarpSize;
|
||||
|
||||
// Compute warp location
|
||||
MatrixCoord warp_footprint{
|
||||
Delta::kRow * Iterations::kRow,
|
||||
Delta::kColumn * Iterations::kColumn,
|
||||
};
|
||||
|
||||
MatrixCoord warp_offset{warp_idx % WarpCount::kRow,
|
||||
warp_idx / WarpCount::kRow};
|
||||
|
||||
// Compute per-lane offset
|
||||
MatrixCoord thread_offset_in_warp{lane_idx / 4,
|
||||
(lane_idx % 4) * kElementsPerAccess};
|
||||
|
||||
MatrixCoord thread_offset_in_threadblock_tile =
|
||||
warp_footprint * warp_offset + thread_offset_in_warp;
|
||||
|
||||
return thread_offset_in_threadblock_tile;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue