CUTLASS 2.2 (#96)

Adds support for NVIDIA Ampere Architecture features. CUDA 11 Toolkit recommended.
This commit is contained in:
Andrew Kerr 2020-06-08 16:17:35 -07:00 committed by GitHub
parent e33d90b361
commit 86931fef85
584 changed files with 51080 additions and 3373 deletions

View File

@ -2,6 +2,22 @@
# CUTLASS 2.x
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
* Fast Tensor Core operations:
* Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
* Tensor Float 32, BFloat16, and double-precision data types
* Mixed integer data types (int8, int4, bin1)
* Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
* Features:
* SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM
* Complex-valued GEMMs targeting NVIDIA Ampere Tensor Cores in double-precision and Tensor Float 32
* Gaussian complex GEMMs using 3m complex multiply algorithm
* Universal GEMM kernel supporting two batch modes and two algorithms for parallel reductions
* Policy updates:
* [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) needed to enable NVIDIA Ampere Architecture features
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`
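All of the kernels listed above are exposed through the same device-level API. As a minimal sketch of that API in the style of the Quick Start Guide, the snippet below instantiates a plain single-precision SIMT GEMM with default tile sizes; the pointer and leading-dimension names are caller-supplied placeholders, and this configuration is illustrative rather than one of the new NVIDIA Ampere Tensor Core kernels.
```
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

// Column-major SGEMM using the default threadblock and warp tile shapes.
using Gemm = cutlass::gemm::device::Gemm<
    float, cutlass::layout::ColumnMajor,   // A
    float, cutlass::layout::ColumnMajor,   // B
    float, cutlass::layout::ColumnMajor>;  // C and D

cutlass::Status run_sgemm(int M, int N, int K,
                          float alpha, float const *A, int lda,
                          float const *B, int ldb,
                          float beta, float *C, int ldc) {
  Gemm gemm_op;
  // Computes D = alpha * A * B + beta * C; D aliases C in this sketch.
  return gemm_op({{M, N, K}, {A, lda}, {B, ldb}, {C, ldc}, {C, ldc}, {alpha, beta}});
}
```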
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
* BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
* API to launch compiled kernel instances for GEMM and planar complex GEMM

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@ -32,7 +32,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
project(CUTLASS VERSION 2.1.0 LANGUAGES CXX)
project(CUTLASS VERSION 2.2.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
find_package(Doxygen QUIET)
@ -84,7 +84,7 @@ endif()
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
if (NOT CUDA_VERSION VERSION_LESS 7.5)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 50)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53)
endif()
if (NOT CUDA_VERSION VERSION_LESS 8.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61)
@ -98,6 +98,9 @@ endif()
if (NOT CUDA_VERSION VERSION_LESS 10.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75)
endif()
if (NOT CUDA_VERSION VERSION_LESS 11.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
@ -154,7 +157,7 @@ endif()
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
set(CUTLASS_ENABLE_F16C ON CACHE BOOL "Enable F16C x86 extensions in host code.")
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
#
# CUTLASS generator cmake configuration
@ -248,8 +251,8 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
endif()
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-pragma-unroll-threshold=100000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-unroll-threshold=5000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -pragma-unroll-threshold=100000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)
string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION})
@ -271,7 +274,7 @@ function(cutlass_apply_cuda_gencode_flags TARGET)
set(NVCC_FLAGS)
set(CLANG_FLAGS)
foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
set(CODES)
if(CUTLASS_NVCC_EMBED_CUBIN)
list(APPEND CODES sm_${ARCH})

View File

@ -9,15 +9,17 @@ This is the official list of CUTLASS developers and contributors.
## DEVELOPERS
Andrew Kerr
Haicheng Wu
Naila Farooqui
Manish Gupta
Dustyn Blasig
Pradeep Ramani
Manish Gupta
Aditya Atluri
Naila Farooqui
Piotr Majcher
Paul Springer
David Tanner
Scott Yokim
Jin Wang
Scott Yokim
Markus Hohnerbach
Aditya Atluri
David Tanner
## CONTRIBUTORS
Timothy Costa
@ -25,12 +27,10 @@ Julien Demouth
Brian Fahs
Michael Goldfarb
Mostafa Hagog
Markus Hohnerbach
Fei Hu
Alan Kaatz
Tina Li
Timmy Liu
Piotr Majcher
Duane Merrill
Kevin Siu
Markus Tavenrath

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@ -206,14 +206,14 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
function(cutlass_correct_source_file_language_property)
if(CUDA_COMPILER MATCHES "clang")
foreach(File ${ARGN})
if(${File} MATCHES ".*\.cu$")
if(File MATCHES ".*\.cu$")
set_source_files_properties(${File} PROPERTIES LANGUAGE CXX)
endif()
endforeach()
endif()
endfunction()
set(CUTLASS_UNITY_BUILD_ENABLED ON CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
function(cutlass_unify_source_files TARGET_ARGS_VAR)

View File

@ -1,4 +1,4 @@
Copyright (c) 2017 - 2019, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

View File

@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 2.1
# CUTLASS 2.2
_CUTLASS 2.1 - April 2020_
_CUTLASS 2.2 - June 2020_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@ -17,14 +17,28 @@ and applications.
To support a wide variety of applications, CUTLASS provides extensive support for
mixed-precision computations, providing specialized data-movement and
multiply-accumulate abstractions for half-precision floating
point (FP16), single-precision floating point (FP32), double-precision floating
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
single-precision floating point (FP32), double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations for
Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta and Turing architectures.
NVIDIA's Volta, Turing, and Ampere architectures.
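As a quick host-side illustration, the reduced-precision element types listed above are exposed as C++ classes in the top-level CUTLASS headers. The sketch below assumes the header names and float conversions found in the CUTLASS 2.2 include tree; the literal values are arbitrary.
```
#include <iostream>

#include "cutlass/half.h"       // FP16: cutlass::half_t
#include "cutlass/bfloat16.h"   // BF16: cutlass::bfloat16_t
#include "cutlass/tfloat32.h"   // TF32: cutlass::tfloat32_t

int main() {
  // Round a float literal to each reduced-precision type, then widen back for printing.
  cutlass::half_t     h(1.5f);
  cutlass::bfloat16_t b(1.5f);
  cutlass::tfloat32_t t(1.5f);

  std::cout << static_cast<float>(h) << " "
            << static_cast<float>(b) << " "
            << static_cast<float>(t) << std::endl;
  return 0;
}
```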
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
# What's New in CUTLASS 2.2
CUTLASS 2.2 is a significant update to CUTLASS adding:
- Coverage of [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
- Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types
- Deep software pipelines using asynchronous copy
- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit)
# What's New in CUTLASS 2.1
CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
@ -32,7 +46,6 @@ CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
# What's New in CUTLASS 2.0
CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:
@ -43,9 +56,6 @@ CUTLASS 2.0 is a substantial refactoring from the previous version, intended to
**See the [CHANGELOG](CHANGELOG.md) for more details.**
See the [functionality listing](media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
# Performance
<p align="center"><img src=/media/images/cutlass-performance-plot.png></p>
@ -53,15 +63,15 @@ supported at each level of the execution model hierarchy.
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
they exhibit performance comparable to cuBLAS for scalar GEMM
computations. The above figure shows CUTLASS performance relative to cuBLAS
for large matrix dimensions on an NVIDIA GeForce 2080 Ti and an NVIDIA TitanV
using CUDA 10.2. Tensor Core operations are implemented using CUDA's
for large matrix dimensions on an NVIDIA GeForce 2080 Ti, an NVIDIA A100, and an NVIDIA TitanV
using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
# Compatibility
CUTLASS requires a C++11 host compiler and
performs best when built with the [CUDA 10.2 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.2, CUDA 10.0, and CUDA 10.1.
performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2.
We have tested the following environments.
@ -70,27 +80,28 @@ We have tested the following environments.
| Windows 10 | Microsoft Visual Studio 2015|
| | Microsoft Visual Studio 2017|
| Ubuntu 16.04 | GCC 5.4.0 |
| Ubuntu 18.04 | GCC 7.3.0 |
| Ubuntu 18.04 | GCC 7.5.0 |
Additionally, CUTLASS may be built with clang.
See [these instructions](media/docs/quickstart.md#clang) for more details.
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
any Maxwell-, Pascal-, Volta-, or Turing- architecture NVIDIA GPU.
any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
|**GPU**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|---|---|---|
|NVIDIA GeForce 1080|9.2| |
|NVIDIA TitanXP|9.2| |
|NVIDIA Tesla P100|9.2| |
|NVIDIA Tesla V100|9.2|10.1|
|NVIDIA TitanV|9.2|10.1|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|10.0|10.2|
|NVIDIA Tesla T4|10.0|10.2|
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|---|---|---|---|
|NVIDIA Tesla P100|6.0|9.2| |
|NVIDIA GeForce 1080|6.1|9.2| |
|NVIDIA TitanXP|6.1|9.2| |
|NVIDIA Tesla V100|7.0|9.2|10.1|
|NVIDIA TitanV|7.0|9.2|10.1|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|NVIDIA Tesla T4|7.5|10.0|10.2|
|NVIDIA A100|8.0|11.0|11.0|
# Documentation
CUTLASS 2.1 is described in the following documents and the accompanying
CUTLASS 2.2 is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
@ -124,7 +135,7 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
```
Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0 and 7.5. To reduce compile time you can specify
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.
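For example, a configuration limited to NVIDIA Ampere GPUs might look like the following (paths and build targets as in the Quick Start Guide):
```
$ mkdir build && cd build
$ cmake .. -DCUTLASS_NVCC_ARCHS=80      # compile only for SM 8.0 (NVIDIA A100)
$ make test_unit -j
```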
@ -210,6 +221,10 @@ examples/
10_planar_complex/ # example demonstrating planar complex GEMM kernels
11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu
13_fused_two_gemms/ # example demonstrating two GEMMs fused in one kernel
```
### Tools
@ -255,29 +270,32 @@ $ make cutlass_profiler -j
Example command line for profiling SGEMM kernels is as follows:
```
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=4352 --n=4096 --k=4096
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
=============================
Problem ID: 1
Provider: CUTLASS
Operation: cutlass_simt_sgemm_128x128_nn
Provider: CUTLASS
OperationKind: gemm
Operation: cutlass_simt_sgemm_128x128_8x2_nn_align1
Disposition: Passed
Status: Success
Status: Success
Verification: ON
Disposition: Passed
Arguments: --m=4352 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 \
--split_k_slices=1 --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 \
--stages=2 --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 \
--max_cc=1024
cuBLAS: Passed
Bytes: 52428800 bytes
FLOPs: 146064539648 flops
Arguments: --m=3456 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \
--batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \
--warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024
Runtime: 10.5424 ms
Memory: 4.63158 GiB/s
Bytes: 180355072 bytes
FLOPs: 115992428544 flops
Math: 13854.9 GFLOP/s
Runtime: 6.73655 ms
Memory: 24.934 GiB/s
Math: 17218.4 GFLOP/s
```
[Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -10,28 +10,35 @@ if((DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) OR
message(STATUS "cuBLAS Disabled.")
elseif(NOT TARGET cublas)
find_path(
_CUBLAS_INCLUDE_DIR cublas.h
PATHS
${CUDA_TOOLKIT_ROOT_DIR}/include
$ENV{CUBLAS_PATH}/include
$ENV{CUDA_PATH}/include
${CUBLAS_PATH}/include
/usr/include)
_CUBLAS_INCLUDE_DIR
NAMES cublas.h
HINTS
${CUBLAS_INCLUDE_PATH}
ENV CUBLAS_INCLUDE_PATH
${CUBLAS_PATH}
ENV CUBLAS_PATH
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
include
)
find_library(
_CUBLAS_LIBRARY cublas
_CUBLAS_LIBRARY
NAMES cublas
HINTS
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
$ENV{CUBLAS_PATH}/lib64
$ENV{CUBLAS_PATH}/lib/x64
$ENV{CUDA_PATH}/lib64
$ENV{CUDA_PATH}/lib/x64
${CUBLAS_PATH}/lib64
${CUBLAS_PATH}/lib/x64
/usr/lib/x86_64-linux-gnu)
${CUBLAS_LIBRARY_PATH}
ENV CUBLAS_LIBRARY_PATH
${_CUBLAS_INCLUDE_DIR}/..
${CUBLAS_PATH}
ENV CUBLAS_PATH
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib64
lib/x64
lib
)
if(_CUBLAS_INCLUDE_DIR AND _CUBLAS_LIBRARY)
@ -79,17 +86,20 @@ if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
$<BUILD_INTERFACE:${CUBLAS_INCLUDE_DIR}>)
find_library(
_CUBLASLT_LIBRARY cublasLt
_CUBLASLT_LIBRARY
NAMES cublasLt
HINTS
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
$ENV{CUBLAS_PATH}/lib64
$ENV{CUBLAS_PATH}/lib/x64
$ENV{CUDA_PATH}/lib64
$ENV{CUDA_PATH}/lib/x64
${CUBLAS_PATH}/lib64
${CUBLAS_PATH}/lib/x64
/usr/lib/x86_64-linux-gnu)
${CUBLAS_LIBRARY_PATH}
ENV CUBLAS_LIBRARY_PATH
${_CUBLAS_INCLUDE_DIR}/..
${CUBLAS_PATH}
ENV CUBLAS_PATH
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib64
lib/x64
lib
)
if(_CUBLASLT_LIBRARY AND NOT TARGET cublasLt)
@ -106,6 +116,8 @@ if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
add_library(nvidia::cublasLt ALIAS cublasLt)
target_link_libraries(cublas INTERFACE cublasLt)
endif()
endif()
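If cuBLAS is installed outside the CUDA Toolkit directory, the hints above allow its location to be supplied either as a CMake cache variable or as an environment variable. An illustrative configuration follows; the install path is a placeholder.
```
$ export CUBLAS_PATH=/opt/cublas        # or: cmake .. -DCUBLAS_PATH=/opt/cublas
$ cmake .. -DCUTLASS_ENABLE_CUBLAS=ON
```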

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -34,6 +34,8 @@
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor_op_multiplicand_sm70.h"
#include "cutlass/layout/tensor_op_multiplicand_sm75.h"
#include "cutlass/layout/tensor_op_multiplicand_sm80.h"
#include "visualize_layout.h"
#include "register_layout.h"
@ -59,18 +61,40 @@ void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase>
// Integer matrix multiply.int4 8832 TN kblock128
{"TensorOpMultiplicand<4,128>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 128>>},
// Integer matrix multiply.int4 16864 TN kblock256
{"TensorOpMultiplicand<4,256>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<4, 256>>},
// Integer matrix multiply 8816 Interleaved-32
{"TensorOpMultiplicand<8,32>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 32>>},
// Integer matrix multiply 8816 TN kblock64
{"TensorOpMultiplicand<8,64>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 64>>},
{"TensorOpMultiplicand<8,128>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<8, 128>>},
// Matrix Multiply 1688 TN kblock32
{"TensorOpMultiplicand<16,32>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<16, 32>>},
// Matrix multiply 1688 NT
{"TensorOpMultiplicand<16,64>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<16, 64>>},
// Matrix multiply 1688.TF32 TN kblock16
{"TensorOpMultiplicand<32,16>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<32, 16>>},
// Matrix multiply 1688.TF32 TN kblock32
{"TensorOpMultiplicand<32,32>",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand<32, 32>>},
// Matrix multiply 1688 NT
{"TensorOpMultiplicandCongruous<32,32>",
new VisualizeLayout<
cutlass::layout::TensorOpMultiplicandCongruous<32, 32>>},
// Matrix multiply 884 NT
{"TensorOpMultiplicandCongruous<64,16>",
new VisualizeLayout<
cutlass::layout::TensorOpMultiplicandCongruous<64, 16>>},
// Matrix multiply 884 TN
{"TensorOpMultiplicand64bCrosswise",
new VisualizeLayout<cutlass::layout::TensorOpMultiplicand64bCrosswise>},
{"TensorOpMultiplicandCongruous<128,4>",
new VisualizeLayout<
cutlass::layout::TensorOpMultiplicandCongruous<128, 4>>},
@ -82,7 +106,7 @@ void RegisterLayouts(std::map<std::string, std::unique_ptr<VisualizeLayoutBase>
cutlass::layout::VoltaTensorOpMultiplicandCongruous<16>>},
{"VoltaTensorOpMultiplicandCrosswise<16,32>",
new VisualizeLayout<
cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>>},
cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>>}
};
for (auto layout : layout_pairs) {

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -65,14 +65,26 @@ void print_usage(std::ostream &out) {
"--extent=64,64 --vectorize=32 --output-shape=256,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<4,128>\" "
"--extent=128,32 --vectorize=32 --output-shape=256,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<4,256>\" "
"--extent=256,16 --vectorize=32 --output-shape=256,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<8,32>\" "
"--extent=32,64 --vectorize=16 --output-shape=128,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<8,64>\" "
"--extent=64,32 --vectorize=16 --output-shape=128,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<8,128>\" "
"--extent=128,16 --vectorize=16 --output-shape=128,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<16,32>\" "
"--extent=32,32 --vectorize=8 --output-shape=64,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<16,64>\" "
"--extent=64,16 --vectorize=8 --output-shape=64,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<32,16>\" "
"--extent=16,32 --vectorize=4 --output-shape=32,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<32,32>\" "
"--extent=32,16 --vectorize=4 --output-shape=32,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicandCongruous<32,32>\" "
"--extent=32,16 --vectorize=4 --output-shape=32,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicandCongruous<64, 16>\" "
"--extent=16,16 --vectorize=2 --output-shape=16,4\n"
<< "$ 03_visualize_layout \"VoltaTensorOpMultiplicandCrosswise<16,32>\" "
"--extent=32,64 --vectorize=4 --output-shape=64,4\n"
<< "$ 03_visualize_layout \"VotlaTensorOpMultiplicandCongruous<16>\" "

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -39,7 +39,7 @@ inner product (1/16th of output), they accumulate to single output matrix.
Writing a single high-performance matrix multiplication kernel is hard but doable, whereas writing
high-performance kernels at scale that work for multiple problem sizes with good abstractions is
really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose
really hard. CUTLASS solves this problem by providing simplified abstractions to compose
multiple sections of a GEMM kernel. When used properly, the kernels can easily hit peak performance
of the GPU.
@ -144,7 +144,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@ -172,17 +172,7 @@ using Gemm = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
ShapeMMAOp,
EpilogueOp>;
int main() {
//
// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
//
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
return -1;
}
int run() {
cudaDeviceProp props;
@ -316,11 +306,30 @@ int main() {
tensor_ref_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
tensor_ref_d.host_view())
? "Passed"
: "Failed")
<< std::endl;
bool passed = cutlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
CUTLASS_CHECK(status);
std::cout << (passed ? "Passed" : "Failed") << std::endl;
return (passed ? 0 : -1);
}
int main() {
//
// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
//
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
// Returning zero, so this test passes when built with older CUDA Toolkits. Its actions are no-op.
return 0;
}
else {
return run();
}
}

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -156,7 +156,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@ -188,15 +188,7 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
SwizzleThreadBlock,
NumStages>;
int main() {
// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
//
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
return -1;
}
int run() {
cudaDeviceProp props;
@ -223,7 +215,7 @@ int main() {
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.nk()); // <- Create matrix B with dimensions N x K
problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
@ -326,12 +318,28 @@ int main() {
tensor_ref_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
tensor_ref_d.host_view())
? "Passed"
: "Failed")
<< std::endl;
bool passed = cutlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
CUTLASS_CHECK(status);
return 0;
std::cout << (passed ? "Passed" : "Failed") << std::endl;
return (passed ? 0 : -1);
}
int main() {
// Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
//
// CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
// Returning zero when built on older Toolkits so tests pass. The actions of this SDK example are no-op.
return 0;
}
else {
return run();
}
}

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -150,12 +150,12 @@ using SmArch = cutlass::arch::Sm75;
using ShapeMMAThreadBlock =
cutlass::gemm::GemmShape<128, 256, 64>; // <- threadblock tile M = 128, N = 256, K = 64
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 16
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 64
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ??
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@ -186,7 +186,7 @@ using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
SwizzleThreadBlock,
NumStages>;
int main() {
int run() {
// Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 10.2.
@ -222,7 +222,7 @@ int main() {
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.nk()); // <- Create matrix B with dimensions N x K
problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
@ -325,12 +325,28 @@ int main() {
tensor_ref_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
tensor_ref_d.host_view())
? "Passed"
: "Failed")
<< std::endl;
bool passed = cutlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
CUTLASS_CHECK(status);
return 0;
std::cout << (passed ? "Passed" : "Failed") << std::endl;
return (passed ? 0 : -1);
}
int main() {
// Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 10.2.
//
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
// Returning zero so this test passes when built on older Toolkits.
return 0;
}
else {
return run();
}
}

View File

@ -500,7 +500,9 @@ int main(int argc, char const **args) {
if (props.major < 7) {
std::cerr << "Volta Tensor Core operations must be run on a machine with compute capability at least 70."
<< std::endl;
return -1;
// Returning zero so this test passes on older architectures even though its actions are no-op.
return 0;
}
else if (props.major == 7 && props.minor <= 2) {
//
@ -508,7 +510,9 @@ int main(int argc, char const **args) {
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
return -1;
// Returning zero so this test passes on older Toolkits even though its actions are no-op.
return 0;
}
}
else if (props.major == 7 && props.minor >= 5) {
@ -517,7 +521,9 @@ int main(int argc, char const **args) {
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
return -1;
// Returning zero so this test passes on older Toolkits even though its actions are no-op.
return 0;
}
}

View File

@ -560,7 +560,9 @@ int main(int argc, char const **args) {
if (props.major < 7) {
std::cerr << "Tensor Core operations must be run on a machine with compute capability at least 70."
<< std::endl;
return -1;
// Returning zero so this passes on older architectures. Its actions are no-op.
return 0;
}
else if (props.major == 7 && props.minor <= 2) {
//
@ -568,7 +570,9 @@ int main(int argc, char const **args) {
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
return -1;
// Returning zero so this passes on older Toolkits. Its actions are no-op.
return 0;
}
}
else if (props.major == 7 && props.minor >= 5) {
@ -577,7 +581,9 @@ int main(int argc, char const **args) {
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
return -1;
// Returning zero so this passes on older Toolkits. Its actions are no-op.
return 0;
}
}

View File

@ -0,0 +1,27 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
12_gemm_bias_relu
gemm_bias_relu.cu
)

View File

@ -0,0 +1,282 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
*/
#include <algorithm>
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D
// The code section below describes matrix layout of input and output matrices. Column Major for
// Matrix A, Row Major for Matrix B and Row Major for Matrix C
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm75;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to
//
// d_ij = max(0, alpha * sum_k(a_ik * b_kj) + beta * c_ij )
//
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- this is the number of elements per
// vectorized memory access. For half
// precision, it's 8 elements. This becomes
// the vector width of math instructions in
// epilogue too
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
// Number of pipelines you want to use
constexpr int NumStages = 2;
using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages>;
int run() {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (!(props.major * 10 + props.minor >= 75)) {
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
<< std::endl;
// Returning zero so this test passes on older architectures. Its actions are no-op.
return 0;
}
const int length_m = 5120;
const int length_n = 4096;
const int length_k = 4096;
// Create a tuple of problem size for matrix multiplication
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
// Initialize tensors using CUTLASS helper functions
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias(
{problem_size.m(), 1}); // <- Create matrix C with dimensions M x 1
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// CUTLASS kernel
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
// reference kernel
// Fill input and output matrices on host using CUTLASS helper functions
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(),
1,
ElementInputA(4),
ElementInputA(-4),
0); // <- Fill matrix A on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(),
1,
ElementInputB(4),
ElementInputB(-4),
0); // <- Fill matrix B on host with uniform-distribution random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_c_bias.host_view(),
1,
ElementOutput(4),
ElementOutput(-4),
0); // <- Fill matrix C on host with uniform-distribution random data
cutlass::reference::host::TensorFill(
tensor_d.host_view()); // <- fill matrix D on host with zeros
cutlass::reference::host::TensorFill(
tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
// Copy data from host to GPU
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c_bias.sync_device();
tensor_d.sync_device();
tensor_ref_d.sync_device();
// Initialize alpha and beta for dot product computation
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(0);
// Split K dimension into 1 partitions
int split_k_slices = 1;
// Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
// instantiated CUTLASS kernel
typename Gemm::Arguments arguments{
problem_size, // <- problem size of matrix multiplication
tensor_a.device_ref(), // <- reference to matrix A on device
tensor_b.device_ref(), // <- reference to matrix B on device
{tensor_c_bias.device_data(), 0}, // <- the C matrix is treated as the bias vector. We can enable the GEMM
// to project away the N dimension by setting the stride to zero.
tensor_d.device_ref(), // <- reference to matrix D on device
{alpha, beta}, // <- tuple of alpha and beta
split_k_slices}; // <- k-dimension split factor
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm_op;
// Initialize CUTLASS kernel with arguments and workspace pointer
cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
// Launch initialized CUTLASS kernel
status = gemm_op();
CUTLASS_CHECK(status);
//
// Create instantiation for device reference gemm kernel
//
cutlass::reference::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementComputeEpilogue,
ElementComputeEpilogue>
gemm_device_reference;
// Launch device reference to compute strictly the product A * B
gemm_device_reference(
problem_size,
alpha,
tensor_a.device_ref(),
tensor_b.device_ref(),
0,
tensor_c_bias.device_ref(),
tensor_ref_d.device_ref());
// Wait for kernels to finish
cudaDeviceSynchronize();
// Copy output data from CUTLASS and reference kernel to host for comparison
tensor_d.sync_host();
tensor_ref_d.sync_host();
// Compute bias + relu in host code
for (int i = 0; i < problem_size.m(); ++i) {
for (int j = 0; j < problem_size.n(); ++j) {
tensor_ref_d.at({i, j}) = std::max(
ElementOutput(0),
ElementOutput(tensor_ref_d.at({i, j}) + beta * tensor_c_bias.at({i, 0}))
);
}
}
// Check if output from CUTLASS kernel and reference kernel are equal or not
std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
tensor_ref_d.host_view())
? "Passed"
: "Failed")
<< std::endl;
CUTLASS_CHECK(status);
return 0;
}
int main() {
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
//
// CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
else {
return run();
}
}

View File

@ -0,0 +1,33 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cutlass_example_add_executable(
13_fused_two_gemms
fused_gemm.cu
)
target_include_directories(
13_fused_two_gemms
PRIVATE
.
)

View File

@ -0,0 +1,190 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "device/b2b_gemm.h"
#include "b2b_gemm_run.h"
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
void run_nonfused_gemm_f16() {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Gemm0 = cutlass::gemm::device::Gemm<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
ThreadblockShape0,
WarpShape0,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
>;
using Gemm1 = cutlass::gemm::device::Gemm<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
ThreadblockShape1,
WarpShape1,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
>;
B2bNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;
std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
if(pass)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
}
void run_fused_gemm_f16() {
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementCompute = cutlass::half_t;
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
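// In the fused kernel the first epilogue is applied to accumulator fragments held in registers,
// so its vector width matches the per-thread accumulator count of one mma tile (kM * kN / 32 threads per warp).
// The second epilogue streams results to global memory and uses a 128-bit access width.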
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
InstructionShape::kM * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
>;
using B2bGemm = cutlass::gemm::device::B2bGemm<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
>;
B2bFusedGemmRun<B2bGemm> fusedGemm;
std::cout << "Running Fused back-to-back FP16 TN GEMMs...\n";
bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
}
////////////////////////////////////////////////////////////////////////////////
#endif //#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)

View File

@@ -0,0 +1,608 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"
#include "helper.h"
#define CHECK_GT(val1, val2) \
do { \
if((val1) <= (val2)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; \
} while (0)
#define CHECK_TRUE(val) \
do { \
if(!(val)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; \
} while (0)
////////////////////////////////////////////////////////////////////////////////
template <typename Gemm0_, typename Gemm1_>
struct B2bNonFusedGemmRun
{
using Gemm0 = Gemm0_;
using Gemm1 = Gemm1_;
using ElementAccumulator = typename Gemm0::ElementAccumulator;
using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
uint64_t seed;
//
// Methods
//
B2bNonFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else {
// TODO: Implement the rest
std::cerr << "Not implemented\n";
return false;
}
return true;
}
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
cutlass::reference::host::TensorFill(
tensor_D0.host_view());
cutlass::reference::host::TensorFill(
tensor_D1.host_view());
cutlass::reference::host::TensorFill(
reference_D0.host_view());
cutlass::reference::host::TensorFill(
reference_D1.host_view());
tensor_A0.sync_device();
tensor_B0.sync_device();
tensor_C0.sync_device();
tensor_D0.sync_device();
tensor_B1.sync_device();
tensor_C1.sync_device();
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
//
// Initialize the GEMM operator
//
typename Gemm0::Arguments arguments_0{
problem_size_0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_C0.device_ref(),
tensor_D0.device_ref(),
{alpha0, beta0}
};
typename Gemm1::Arguments arguments_1{
problem_size_1,
tensor_D0.device_ref(),
tensor_B1.device_ref(),
tensor_C1.device_ref(),
tensor_D1.device_ref(),
{alpha1, beta1}
};
Gemm0 gemm_op_0;
Gemm1 gemm_op_1;
cutlass::Status status = gemm_op_0.initialize(arguments_0);
CUTLASS_CHECK(status);
status = gemm_op_1.initialize(arguments_1);
CUTLASS_CHECK(status);
//
// Run the GEMM
//
cudaEvent_t start, stop1, stop2;
cudaEventCreate(&start);
cudaEventCreate(&stop1);
cudaEventCreate(&stop2);
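// Launch each GEMM 100 times; the event deltas below are divided by 100 to report average per-launch time.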
cudaEventRecord(start);
for(int i = 0; i < 100; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < 100; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop2);
cudaDeviceSynchronize();
float gemm0Time, gemm1Time, totalTime;
cudaEventElapsedTime(&gemm0Time, start, stop1);
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
cudaEventElapsedTime(&totalTime, start, stop2);
std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n";
std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n";
std::cout << "total time " << totalTime / 100.0 << " ms\n";
tensor_D0.sync_host();
tensor_D1.sync_host();
//
// Verify
//
cutlass::reference::device::Gemm<
typename Gemm0::ElementA, typename Gemm0::LayoutA,
typename Gemm0::ElementB, typename Gemm0::LayoutB,
typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
ElementAccumulator, typename Gemm0::Operator>
reference_gemm_0;
cutlass::reference::device::Gemm<
typename Gemm1::ElementA, typename Gemm1::LayoutA,
typename Gemm1::ElementB, typename Gemm1::LayoutB,
typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
ElementAccumulator, typename Gemm1::Operator>
reference_gemm_1;
reference_gemm_0(
problem_size_0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
tensor_C0.device_ref(),
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
tensor_C1.device_ref(),
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
// Wait for kernels to finish
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
if (!passed) {
std::stringstream fname;
fname << "error_B2bGemm_device_nonfused.txt";
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
<< "\nD0 =\n" << tensor_D0.host_view()
<< "\nB1 =\n" << tensor_B1.host_view()
<< "\nC1 =\n" << tensor_C1.host_view()
<< "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view();
}
return passed;
}
};
template <typename B2bGemm_>
struct B2bFusedGemmRun
{
using B2bGemm = B2bGemm_;
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
uint64_t seed;
//
// Methods
//
B2bFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else {
// TODO: Implement the rest
std::cerr << "Not implemented\n";
return false;
}
return true;
}
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
// cutlass::HostTensor<
// typename B2bGemm::ElementC,
// typename B2bGemm::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
cutlass::reference::host::TensorFill(
tensor_D1.host_view());
cutlass::reference::host::TensorFill(
reference_D0.host_view());
cutlass::reference::host::TensorFill(
reference_D1.host_view());
tensor_A0.sync_device();
tensor_B0.sync_device();
tensor_C0.sync_device();
tensor_B1.sync_device();
tensor_C1.sync_device();
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
//
// Initialize the GEMM operator
//
typename B2bGemm::Arguments arguments{
problem_size_0,
problem_size_1,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
tensor_C0.device_ref(),
tensor_B1.device_ref(),
tensor_C1.device_ref(),
tensor_D1.device_ref(),
{alpha0, beta0},
{alpha1, beta1},
};
B2bGemm b2b_gemm_op;
cutlass::Status status = b2b_gemm_op.initialize(arguments);
CUTLASS_CHECK(status);
//
// Run the GEMM
//
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start);
for(int i = 0; i < 100; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop);
cudaDeviceSynchronize();
float gemmTime;
cudaEventElapsedTime(&gemmTime, start, stop);
std::cout << "time " << gemmTime / 100.0 << " ms\n";
//tensor_D0.sync_host();
tensor_D1.sync_host();
//
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_0, reference_gemm_1;
reference_gemm_0(
problem_size_0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
tensor_C0.device_ref(),
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
tensor_C1.device_ref(),
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
if (!passed) {
std::stringstream fname;
fname << "error_B2bGemm_device_fused.txt";
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
// << "\nD0 =\n" << tensor_D0.host_view()
<< "\nB1 =\n" << tensor_B1.host_view()
<< "\nC1 =\n" << tensor_C1.host_view()
<< "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view();
}
return passed;
}
};
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,190 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "device/b2b_gemm.h"
#include "b2b_interleaved_gemm_run.h"
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
void run_nonfused_gemm_s8() {
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementCompute = float;
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 32, 64>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
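// The 32-element interleaved layouts and the 8x8x16 mma shape target the Turing int8 Tensor Core path.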
using Gemm0 = cutlass::gemm::device::Gemm<
int8_t,
cutlass::layout::ColumnMajorInterleaved<32>,
int8_t,
cutlass::layout::RowMajorInterleaved<32>,
ElementOutput,
cutlass::layout::ColumnMajorInterleaved<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
ThreadblockShape0,
WarpShape0,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
>;
using Gemm1 = cutlass::gemm::device::Gemm<
int8_t,
cutlass::layout::ColumnMajorInterleaved<32>,
int8_t,
cutlass::layout::RowMajorInterleaved<32>,
ElementOutput,
cutlass::layout::ColumnMajorInterleaved<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
ThreadblockShape1,
WarpShape1,
InstructionShape,
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
>;
B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1, 32> nonFusedGemm;
std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
if(pass)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
}
void run_fused_gemm_s8() {
using ElementOutput = int8_t;
using ElementAccumulator = int32_t;
using ElementCompute = float;
cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
ElementCompute alpha0 = ElementCompute(2);
ElementCompute beta0 = ElementCompute(0);
ElementCompute alpha1 = ElementCompute(2);
ElementCompute beta1 = ElementCompute(1);
using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
using EpilogueOutputOp0 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
InstructionShape::kM * InstructionShape::kN / 32,
ElementAccumulator,
ElementCompute
>;
using EpilogueOutputOp1 =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
64 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementCompute
>;
using B2bGemm = cutlass::gemm::device::B2bGemm<
int8_t,
cutlass::layout::ColumnMajorInterleaved<32>,
int8_t,
cutlass::layout::RowMajorInterleaved<32>,
ElementOutput,
cutlass::layout::ColumnMajorInterleaved<32>,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
2
>;
B2bInterleavedFusedGemmRun<B2bGemm, 32> fusedGemm;
std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n";
bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
if(passed)
std::cout << "Pass\n";
else
std::cout << "Fail\n";
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)

View File

@@ -0,0 +1,633 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/util/reference/device/gemm.h"
#include "helper.h"
#define CHECK_GT(val1, val2) \
do { \
if((val1) <= (val2)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; \
} while (0)
#define CHECK_TRUE(val) \
do { \
if(!(val)) \
std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; \
} while (0)
template <typename Gemm0_, typename Gemm1_, int InterleavedK_>
struct B2bInterleavedNonFusedGemmRun
{
using Gemm0 = Gemm0_;
using Gemm1 = Gemm1_;
using ElementAccumulator = typename Gemm0::ElementAccumulator;
using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
uint64_t seed;
//
// Methods
//
B2bInterleavedNonFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else {
// TODO: Implement the rest
std::cerr << "Not implemented\n";
return false;
}
return true;
}
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename Gemm0::ElementA,
typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementB,
typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm0::ElementC,
typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementB,
typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
typename Gemm1::ElementC,
typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
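// The interleaved Tensor Core kernels consume B in a permuted element order; reorder_column
// performs that host-side permutation so the mainloop can load B directly.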
//Reorder B0 and B1
cutlass::reorder_column<InterleavedK_>(
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
cutlass::reorder_column<InterleavedK_>(
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
cutlass::reference::host::TensorFill(
tensor_D0.host_view());
cutlass::reference::host::TensorFill(
tensor_D1.host_view());
cutlass::reference::host::TensorFill(
reference_D0.host_view());
cutlass::reference::host::TensorFill(
reference_D1.host_view());
tensor_A0.sync_device();
tensor_B0.sync_device();
tensor_B0_reordered.sync_device();
tensor_C0.sync_device();
tensor_D0.sync_device();
tensor_B1.sync_device();
tensor_B1_reordered.sync_device();
tensor_C1.sync_device();
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
//
// Initialize the GEMM operator
//
typename Gemm0::Arguments arguments_0{
problem_size_0,
tensor_A0.device_ref(),
tensor_B0_reordered.device_ref(),
tensor_C0.device_ref(),
tensor_D0.device_ref(),
{alpha0, beta0}
};
typename Gemm1::Arguments arguments_1{
problem_size_1,
tensor_D0.device_ref(),
tensor_B1_reordered.device_ref(),
tensor_C1.device_ref(),
tensor_D1.device_ref(),
{alpha1, beta1}
};
Gemm0 gemm_op_0;
Gemm1 gemm_op_1;
cutlass::Status status = gemm_op_0.initialize(arguments_0);
CUTLASS_CHECK(status);
status = gemm_op_1.initialize(arguments_1);
CUTLASS_CHECK(status);
//
// Run the GEMM
//
cudaEvent_t start, stop1, stop2;
cudaEventCreate(&start);
cudaEventCreate(&stop1);
cudaEventCreate(&stop2);
cudaEventRecord(start);
for(int i = 0; i < 100; i++) {
status = gemm_op_0();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop1);
for(int i = 0; i < 100; i++) {
status = gemm_op_1();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop2);
cudaDeviceSynchronize();
float gemm0Time, gemm1Time, totalTime;
cudaEventElapsedTime(&gemm0Time, start, stop1);
cudaEventElapsedTime(&gemm1Time, stop1, stop2);
cudaEventElapsedTime(&totalTime, start, stop2);
std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n";
std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n";
std::cout << "total time " << totalTime / 100.0 << " ms\n";
tensor_D0.sync_host();
tensor_D1.sync_host();
//
// Verify
//
cutlass::reference::device::Gemm<
typename Gemm0::ElementA, typename Gemm0::LayoutA,
typename Gemm0::ElementB, typename Gemm0::LayoutB,
typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
ElementAccumulator, typename Gemm0::Operator>
reference_gemm_0;
cutlass::reference::device::Gemm<
typename Gemm1::ElementA, typename Gemm1::LayoutA,
typename Gemm1::ElementB, typename Gemm1::LayoutB,
typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
ElementAccumulator, typename Gemm1::Operator>
reference_gemm_1;
reference_gemm_0(
problem_size_0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
tensor_C0.device_ref(),
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
tensor_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
tensor_C1.device_ref(),
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
if (!passed) {
std::stringstream fname;
fname << "error_B2bGemm_device_interleaved_nonfused.txt";
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
<< "\nD0 =\n" << tensor_D0.host_view()
<< "\nB1 =\n" << tensor_B1.host_view()
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
<< "\nC1 =\n" << tensor_C1.host_view()
<< "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view();
}
return passed;
}
};
template <typename B2bGemm_, int InterleavedK_>
struct B2bInterleavedFusedGemmRun
{
using B2bGemm = B2bGemm_;
using ElementAccumulator = typename B2bGemm::ElementAccumulator;
using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
uint64_t seed;
//
// Methods
//
B2bInterleavedFusedGemmRun(
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
uint64_t seed_ = 2080
):
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
/// Helper to initialize a tensor view
template <typename Element, typename Layout>
bool initialize_tensor(
cutlass::TensorView<Element, Layout> view,
cutlass::Distribution::Kind dist_kind,
uint64_t seed) {
if (dist_kind == cutlass::Distribution::Uniform) {
cutlass::reference::host::TensorFillRandomUniform(
view, seed, 2, -2, 0);
}
else if (dist_kind == cutlass::Distribution::Identity) {
cutlass::reference::host::TensorFillIdentity(view);
}
else if (dist_kind == cutlass::Distribution::Sequential) {
cutlass::reference::host::BlockFillSequential(
view.data(), view.capacity());
}
else {
// TODO: Implement the rest
std::cerr << "Not implemented\n";
return false;
}
return true;
}
/// Executes one test
bool run(
cutlass::gemm::GemmCoord problem_size_0,
cutlass::gemm::GemmCoord problem_size_1,
ElementCompute alpha0 = ElementCompute(1),
ElementCompute beta0 = ElementCompute(0),
ElementCompute alpha1 = ElementCompute(1),
ElementCompute beta1 = ElementCompute(0),
bool relu = true) {
//
// Allocate the GEMM workspace
//
cutlass::HostTensor<
typename B2bGemm::ElementA,
typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
// cutlass::HostTensor<
// typename B2bGemm::ElementC,
// typename B2bGemm::LayoutC> tensor_D0(problem_size_0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementB,
typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
cutlass::HostTensor<
typename B2bGemm::ElementC,
typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
//Reorder B0 and B1: B0 is interleaved by the instruction-level K extent, B1 by InterleavedK_
cutlass::reorder_column<B2bGemm::InstructionShape::kK>(
tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
cutlass::reorder_column<InterleavedK_>(
tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
cutlass::reference::host::TensorFill(
tensor_D1.host_view());
cutlass::reference::host::TensorFill(
reference_D0.host_view());
cutlass::reference::host::TensorFill(
reference_D1.host_view());
tensor_A0.sync_device();
tensor_B0.sync_device();
tensor_B0_reordered.sync_device();
tensor_C0.sync_device();
//tensor_D0.sync_device();
tensor_B1.sync_device();
tensor_B1_reordered.sync_device();
tensor_C1.sync_device();
tensor_D1.sync_device();
reference_D0.sync_device();
reference_D1.sync_device();
//
// Initialize the GEMM operator
//
typename B2bGemm::Arguments arguments{
problem_size_0,
problem_size_1,
tensor_A0.device_ref(),
tensor_B0_reordered.device_ref(),
tensor_C0.device_ref(),
tensor_B1_reordered.device_ref(),
tensor_C1.device_ref(),
tensor_D1.device_ref(),
{alpha0, beta0},
{alpha1, beta1},
1, /*split_k_slices*/
};
B2bGemm b2b_gemm_op;
cutlass::Status status = b2b_gemm_op.initialize(arguments);
CUTLASS_CHECK(status);
//
// Run the GEMM
//
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start);
for(int i = 0; i < 100; i++) {
status = b2b_gemm_op();
CUTLASS_CHECK(status);
}
cudaEventRecord(stop);
cudaDeviceSynchronize();
float gemmTime;
cudaEventElapsedTime(&gemmTime, start, stop);
std::cout << "time " << gemmTime / 100.0 << " ms\n";
//tensor_D0.sync_host();
tensor_D1.sync_host();
//
// Verify
//
cutlass::reference::device::Gemm<
typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
ElementAccumulator, typename B2bGemm::Operator>
reference_gemm_0, reference_gemm_1;
reference_gemm_0(
problem_size_0,
alpha0,
tensor_A0.device_ref(),
tensor_B0.device_ref(),
beta0,
tensor_C0.device_ref(),
reference_D0.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D0.device_view());
}
reference_gemm_1(
problem_size_1,
alpha1,
reference_D0.device_ref(),
tensor_B1.device_ref(),
beta1,
tensor_C1.device_ref(),
reference_D1.device_ref()
);
if(relu) {
cutlass::reference::device::TensorReLu(reference_D1.device_view());
}
cudaDeviceSynchronize();
reference_D0.sync_host();
reference_D1.sync_host();
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
bool passed = cutlass::reference::host::TensorEquals(
reference_D1.host_view(),
tensor_D1.host_view());
CHECK_TRUE(passed);
if (!passed) {
std::stringstream fname;
fname << "error_B2bGemm_device_interleaved_fused.txt";
std::cerr << "Dumping results in " << fname.str() << "\n";
std::ofstream file(fname.str());
file
<< "A0 =\n" << tensor_A0.host_view()
<< "\nB0 =\n" << tensor_B0.host_view()
<< "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
<< "\nC0 =\n" << tensor_C0.host_view()
// << "\nD0 =\n" << tensor_D0.host_view()
<< "\nB1 =\n" << tensor_B1.host_view()
<< "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
<< "\nC1 =\n" << tensor_C1.host_view()
<< "\n\nReference =\n" << reference_D1.host_view()
<< "\nComputed =\n" << tensor_D1.host_view();
}
return passed;
}
};
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,439 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined back-to-back GEMM kernel. Does not support batching or split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "kernel/b2b_gemm.h"
#include "kernel/default_b2b_gemm.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator_ = ElementC_,
/// Operator class tag
typename OperatorClass_ = arch::OpClassSimt,
/// Tag indicating architecture to tune for
typename ArchTag_ = arch::Sm70,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Epilogue output operator
typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
/// Number of stages used in the pipelined mainloop
int Stages =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kStages,
/// Access granularity of A matrix in units of elements
int AlignmentA =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentA,
/// Access granularity of B matrix in units of elements
int AlignmentB =
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
ElementC_, ElementAccumulator_>::kAlignmentB,
/// If true, kernel supports split-K with serial reduction
bool SplitKSerial = false,
/// Operation performed by GEMM
typename Operator_ = typename DefaultGemmConfiguration<
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
ElementAccumulator_>::Operator,
/// Whether Beta is zero or not
bool IsBetaZero = false>
class B2bGemm {
public:
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
using ElementAccumulator = ElementAccumulator_;
using OperatorClass = OperatorClass_;
using ArchTag = ArchTag_;
using ThreadblockShape0 = ThreadblockShape0_;
using ThreadblockShape1 = ThreadblockShape1_;
using WarpShape0 = WarpShape0_;
using WarpShape1 = WarpShape1_;
using InstructionShape = InstructionShape_;
using EpilogueOutputOp0 = EpilogueOutputOp0_;
using EpilogueOutputOp1 = EpilogueOutputOp1_;
using ThreadblockSwizzle = ThreadblockSwizzle_;
using Operator = Operator_;
static int const kStages = Stages;
static int const kAlignmentA = AlignmentA;
static int const kAlignmentB = AlignmentB;
static int const kAlignmentC = EpilogueOutputOp1::kCount;
static bool const kSplitKSerial = SplitKSerial;
static bool const kIsBetaZero = IsBetaZero;
static ComplexTransform const kTransformA = ComplexTransform::kNone;
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Define the kernel
using B2bGemmKernel = typename kernel::DefaultB2bGemm<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementC,
LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
ThreadblockSwizzle,
kStages,
kSplitKSerial,
Operator,
kIsBetaZero
>::B2bGemmKernel;
/// Argument structure
struct Arguments {
//
// Data members
//
GemmCoord problem_size_0;
GemmCoord problem_size_1;
TensorRef<ElementA const, LayoutA> ref_A0;
TensorRef<ElementB const, LayoutB> ref_B0;
TensorRef<ElementC const, LayoutC> ref_C0;
TensorRef<ElementB const, LayoutB> ref_B1;
TensorRef<ElementC const, LayoutC> ref_C1;
TensorRef<ElementC, LayoutC> ref_D1;
typename EpilogueOutputOp0::Params epilogue0;
typename EpilogueOutputOp1::Params epilogue1;
int split_k_slices;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
}
/// Constructs an Arguments structure
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_0_,
GemmCoord problem_size_1_,
TensorRef<ElementA const, LayoutA> ref_A0_,
TensorRef<ElementB const, LayoutB> ref_B0_,
TensorRef<ElementC const, LayoutC> ref_C0_,
TensorRef<ElementB const, LayoutB> ref_B1_,
TensorRef<ElementC const, LayoutC> ref_C1_,
TensorRef<ElementC, LayoutC> ref_D1_,
typename EpilogueOutputOp0::Params epilogue0_ =
typename EpilogueOutputOp0::Params(),
typename EpilogueOutputOp1::Params epilogue1_ =
typename EpilogueOutputOp1::Params(),
int split_k_slices_ = 1
):
problem_size_0(problem_size_0_),
problem_size_1(problem_size_1_),
ref_A0(ref_A0_),
ref_B0(ref_B0_),
ref_C0(ref_C0_),
ref_B1(ref_B1_),
ref_C1(ref_C1_),
ref_D1(ref_D1_),
epilogue0(epilogue0_),
epilogue1(epilogue1_),
split_k_slices(split_k_slices_) {
}
};
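/// Taken together, the arguments describe the fused computation
///   D1 = epilogue1(alpha1 * epilogue0(alpha0 * A0 * B0 + beta0 * C0) * B1 + beta1 * C1),
/// where the first GEMM's output is consumed directly as the A operand of the second GEMM.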
private:
/// Kernel parameters object
typename B2bGemmKernel::Params params_;
public:
/// Constructs the GEMM.
B2bGemm() { }
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const &args) {
if (!kSplitKSerial && args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
Status status = B2bGemmKernel::can_implement(
args.problem_size_0,
args.problem_size_1,
args.ref_A0.non_const_ref(),
args.ref_B0.non_const_ref(),
args.ref_C0.non_const_ref(),
args.ref_B1.non_const_ref(),
args.ref_C1.non_const_ref(),
args.ref_D1
);
if (status != Status::kSuccess) {
return status;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const &args) {
size_t bytes = 0;
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.split_k_slices);
if (kSplitKSerial && args.split_k_slices > 1) {
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
}
return bytes;
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size_0,
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
args.split_k_slices);
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
// args.problem_size_1,
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
// args.split_k_slices);
if (kSplitKSerial) {
if (args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
size_t bytes = get_workspace_size(args);
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
}
else {
if (args.split_k_slices > 1) {
return Status::kErrorInvalidProblem;
}
}
// Initialize the Params structure
params_ = typename B2bGemmKernel::Params{
args.problem_size_0,
args.problem_size_1,
grid_shape,
args.ref_A0.non_const_ref(),
args.ref_B0.non_const_ref(),
args.ref_C0.non_const_ref(),
args.ref_B1.non_const_ref(),
args.ref_C1.non_const_ref(),
args.ref_D1,
args.epilogue0,
args.epilogue1,
static_cast<int *>(workspace),
};
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const &args, void *workspace = nullptr) {
if (kSplitKSerial && args.split_k_slices > 1) {
if (!workspace) {
return Status::kErrorWorkspaceNull;
}
}
params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
params_.ref_D1.reset(args.ref_D1.data());
params_.output_op_0 = args.epilogue0;
params_.output_op_1 = args.epilogue1;
params_.semaphore = static_cast<int *>(workspace);
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr) {
ThreadblockSwizzle threadblock_swizzle;
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
dim3 block(B2bGemmKernel::kThreadCount, 1, 1);
cudaError_t result;
int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));
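// Kernels that need more than 48 KB of shared memory must opt in to a larger dynamic
// shared memory limit and carveout before launch.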
if (smem_size >= (48 << 10)) {
result = cudaFuncSetAttribute(Kernel<B2bGemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
result = cudaFuncSetAttribute(
Kernel<B2bGemmKernel>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
if (result != cudaSuccess) {
return Status::kErrorInternal;
}
}
cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);
result = cudaGetLastError();
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr) {
return run(stream);
}
/// Runs the kernel using initialized state.
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace);
if (status == Status::kSuccess) {
status = run(stream);
}
return status;
}
};
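// Illustrative usage (a minimal sketch; see the run_fused_gemm_* examples for complete
// template configurations -- the tensor and scalar names below are placeholders):
//
//   B2bGemm b2b_gemm_op;
//   typename B2bGemm::Arguments args(
//     problem_size_0, problem_size_1,
//     tensor_A0.device_ref(), tensor_B0.device_ref(), tensor_C0.device_ref(),
//     tensor_B1.device_ref(), tensor_C1.device_ref(), tensor_D1.device_ref(),
//     {alpha0, beta0}, {alpha1, beta1});
//   cutlass::Status status = b2b_gemm_op.initialize(args);
//   if (status == cutlass::Status::kSuccess) {
//     status = b2b_gemm_op();   // launches the fused kernel
//   }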
} // namespace device
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,74 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
*/
#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h"
#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h"
int run() {
cudaDeviceProp props;
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (error != cudaSuccess) {
std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
return -1;
}
if (!(props.major * 10 + props.minor >= 75)) {
std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
<< std::endl;
// Returning zero so this test passes on devices without Turing Tensor Cores. Its actions are a no-op.
return 0;
}
#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
run_nonfused_gemm_f16();
run_fused_gemm_f16();
run_nonfused_gemm_s8();
run_fused_gemm_s8();
#endif
return 0;
}
int main() {
// Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
//
// CUTLASS must be compiled with the CUDA 10.2 Toolkit or later to run these examples.
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
else {
return run();
}
}
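// A small refactoring sketch (not part of the original example): the device capability check
// performed at the top of run() can be factored into a reusable helper. The function name is
// hypothetical.
#include <iostream>
#include <cuda_runtime.h>

inline bool device_supports_sm75() {
  cudaDeviceProp props;
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: "
              << cudaGetErrorString(error) << std::endl;
    return false;
  }
  // Turing Tensor Core operations require compute capability 7.5 or higher.
  return props.major * 10 + props.minor >= 75;
}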

View File

@@ -0,0 +1,407 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a pipelined back-to-back (B2b) fused GEMM kernel. Does not support batching; serial split-K reduction is optionally enabled via the SplitKSerial template parameter.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
>
struct B2bGemm {
using B2bMma = B2bMma_;
using Epilogue = Epilogue_;
using OutputOp0 = typename B2bMma::OutputOp;
using OutputOp1 = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
/// Warp count (concept: GemmShape)
using WarpCount0 = typename B2bMma::WarpCount0;
static int const kThreadCount = 32 * WarpCount0::kCount;
/// Parameters structure
struct Params {
cutlass::gemm::GemmCoord problem_size_0;
cutlass::gemm::GemmCoord problem_size_1;
cutlass::gemm::GemmCoord grid_tiled_shape;
typename B2bMma::IteratorA0::Params params_A0;
typename B2bMma::IteratorA0::TensorRef ref_A0;
typename B2bMma::IteratorB0::Params params_B0;
typename B2bMma::IteratorB0::TensorRef ref_B0;
typename Epilogue::OutputTileIterator::Params params_C0;
typename Epilogue::OutputTileIterator::TensorRef ref_C0;
typename B2bMma::IteratorB1::Params params_B1;
typename B2bMma::IteratorB1::TensorRef ref_B1;
typename Epilogue::OutputTileIterator::Params params_C1;
typename Epilogue::OutputTileIterator::TensorRef ref_C1;
typename Epilogue::OutputTileIterator::Params params_D1;
typename Epilogue::OutputTileIterator::TensorRef ref_D1;
typename OutputOp0::Params output_op_0;
typename OutputOp1::Params output_op_1;
int *semaphore;
int gemm_k_iterations_0;
int gemm_k_size_0;
int gemm_k_iterations_1;
int gemm_k_size_1;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params(): semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
gemm_k_iterations_1(0), gemm_k_size_1(0) { }
CUTLASS_HOST_DEVICE
Params(
cutlass::gemm::GemmCoord const & problem_size_0,
cutlass::gemm::GemmCoord const & problem_size_1,
cutlass::gemm::GemmCoord const & grid_tiled_shape,
typename B2bMma::IteratorA0::TensorRef ref_A0,
typename B2bMma::IteratorB0::TensorRef ref_B0,
typename Epilogue::OutputTileIterator::TensorRef ref_C0,
typename B2bMma::IteratorB1::TensorRef ref_B1,
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
typename Epilogue::OutputTileIterator::TensorRef ref_D1,
typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
int *workspace = nullptr
):
problem_size_0(problem_size_0),
problem_size_1(problem_size_1),
grid_tiled_shape(grid_tiled_shape),
params_A0(ref_A0.layout()),
ref_A0(ref_A0),
params_B0(ref_B0.layout()),
ref_B0(ref_B0),
params_C0(ref_C0.layout()),
ref_C0(ref_C0),
params_B1(ref_B1.layout()),
ref_B1(ref_B1),
params_C1(ref_C1.layout()),
ref_C1(ref_C1),
params_D1(ref_D1.layout()),
ref_D1(ref_D1),
output_op_0(output_op_0),
output_op_1(output_op_1) {
int total_gemm_k_iterations_0 = (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
gemm_k_iterations_0 = (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
gemm_k_size_0 = gemm_k_iterations_0 * B2bMma::Shape0::kK;
int total_gemm_k_iterations_1 = (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
gemm_k_iterations_1 = (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK;
semaphore = workspace;
}
};
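// Worked example of the K-partitioning above (illustrative values, not part of the
// original file): with problem_size_0.k() = 256, B2bMma::Shape0::kK = 32, and
// grid_tiled_shape.k() = 2,
//   total_gemm_k_iterations_0 = (256 + 31) / 32 = 8
//   gemm_k_iterations_0       = (8 + 1) / 2     = 4
//   gemm_k_size_0             = 4 * 32          = 128
// so each of the two K partitions covers 128 elements of the reduction dimension.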
/// Shared memory storage structure
union SharedStorage {
typename B2bMma::B2bMmaSharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
B2bGemm() { }
/// Determines whether kernel satisfies alignment
static Status can_implement(
cutlass::gemm::GemmCoord const & problem_size_0,
cutlass::gemm::GemmCoord const & problem_size_1,
typename B2bMma::IteratorA0::TensorRef ref_A0,
typename B2bMma::IteratorB0::TensorRef ref_B0,
typename Epilogue::OutputTileIterator::TensorRef ref_C0,
typename B2bMma::IteratorB1::TensorRef ref_B1,
typename Epilogue::OutputTileIterator::TensorRef ref_C1,
typename Epilogue::OutputTileIterator::TensorRef ref_D1) {
static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements;
static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_B0, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_C0, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_B1, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_C1, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(ref_D1, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if ((problem_size_0.m() % kAlignmentA) || (problem_size_0.k() % kAlignmentA) ||
(problem_size_0.n() % kAlignmentB) || (problem_size_0.k() % kAlignmentB) ||
(problem_size_0.m() % kAlignmentC) || (problem_size_0.n() % kAlignmentC) ||
(problem_size_1.m() % kAlignmentA) || (problem_size_1.k() % kAlignmentA) ||
(problem_size_1.n() % kAlignmentB) || (problem_size_1.k() % kAlignmentB) ||
(problem_size_1.m() % kAlignmentC) || (problem_size_1.n() % kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const &params, SharedStorage &shared_storage) {
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A0{
threadblock_tile_offset.m() * B2bMma::Shape0::kM,
threadblock_tile_offset.k() * params.gemm_k_size_0,
};
cutlass::MatrixCoord tb_offset_B0{
threadblock_tile_offset.k() * params.gemm_k_size_0,
threadblock_tile_offset.n() * B2bMma::Shape0::kN
};
cutlass::MatrixCoord tb_offset_B1{
threadblock_tile_offset.k() * params.gemm_k_size_1,
threadblock_tile_offset.n() * B2bMma::Shape1::kN
};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_0 = min(
params.problem_size_0.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
// Problem size is a function of threadblock index in the K dimension
int problem_size_k_1 = min(
params.problem_size_1.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
// Compute threadblock-scoped matrix multiply-add
// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename B2bMma::IteratorA0 iterator_A0(
params.params_A0,
params.ref_A0.data(),
{params.problem_size_0.m(), problem_size_k_0},
thread_idx,
tb_offset_A0);
typename B2bMma::IteratorB0 iterator_B0(
params.params_B0,
params.ref_B0.data(),
{problem_size_k_0, params.problem_size_0.n()},
thread_idx,
tb_offset_B0);
typename B2bMma::IteratorB1 iterator_B1(
params.params_B1,
params.ref_B1.data(),
{problem_size_k_1, params.problem_size_1.n()},
thread_idx,
tb_offset_B1);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
OutputOp0 output_op_0(params.output_op_0);
// Construct thread-scoped matrix multiply
B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename B2bMma::FragmentC0 src_accum;
typename B2bMma::FragmentC1 accumulators;
src_accum.clear();
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
// Compute threadblock-scoped matrix multiply-add
b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, iterator_B1, src_accum, output_op_0);
}
//
// Epilogue
//
OutputOp1 output_op_1(params.output_op_1);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
//assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * B2bMma::Shape1::kM,
threadblock_tile_offset.n() * B2bMma::Shape1::kN
);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op_1.set_k_partition(threadblock_tile_offset.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C1(
params.params_C1,
params.ref_C1.data(),
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D1(
params.params_D1,
params.ref_D1.data(),
params.problem_size_1.mn(),
thread_idx,
threadblock_offset
);
Epilogue epilogue(
shared_storage.epilogue,
thread_idx,
warp_idx,
lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k()) {
iterator_C1 = iterator_D1;
}
semaphore.wait(threadblock_tile_offset.k());
__threadfence();
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
__threadfence();
semaphore.release(lock);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
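/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Reading aid (not part of the original file): in the serial split-K path above, every
// threadblock that maps to the same (m, n) output tile shares one semaphore slot. Partition k
// waits until the slot holds the value k, reads the previous partition's partial result back
// through iterator_C1 (redirected to the D tensor for k > 0), accumulates its contribution in
// the epilogue, then releases the slot with k + 1. The final partition releases 0 so the
// semaphore is ready for a subsequent launch over the same workspace.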

View File

@@ -0,0 +1,296 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
*notice, this list of conditions and the following disclaimer in the
*documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its
*contributors may be used to endorse or promote products derived from this
*software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
*DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
*INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
*DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
*OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
*NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
*EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_pipelined.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "kernel/b2b_gemm.h"
#include "threadblock/default_b2b_mma.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp0,
/// Epilogue output operator
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Beta is zero or not
bool IsBetaZero = false
>
struct DefaultB2bGemm;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for Turing Architecture
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp0,
/// Epilogue output operator
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// If true, kernel is configured to support serial reduction in the epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator
>
struct DefaultB2bGemm<
ElementA, LayoutA, kAlignmentA,
ElementB, LayoutB, kAlignmentB,
ElementC, layout::RowMajor,
ElementAccumulator,
arch::OpClassTensorOp,
arch::Sm75,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
EpilogueOutputOp0,
EpilogueOutputOp1,
ThreadblockSwizzle,
2,
SplitKSerial,
Operator
> {
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm75,
ThreadblockShape0,
ThreadblockShape1,
WarpShape0,
WarpShape1,
InstructionShape,
2,
Operator,
EpilogueOutputOp0
>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue
using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
ThreadblockShape1,
typename B2bMma::Operator1,
kPartitionsK1,
EpilogueOutputOp1,
EpilogueOutputOp1::kCount
>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};
/// Partial specialization for Turing IMMA Interleaved layout
template <
/// Element type for A matrix operand
typename ElementA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp0,
/// Epilogue output operator
typename EpilogueOutputOp1,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of Interleaved k
int InterleavedK,
/// If true, kernel is configured to support serial reduction in the
/// epilogue
bool SplitKSerial,
/// Operation performed by GEMM
typename Operator,
/// Is Beta zero or not
bool IsBetaZero>
struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
kAlignmentA, ElementB,
layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
int32_t, arch::OpClassTensorOp, arch::Sm75,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> {
using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
using ElementAccumulator = int32_t;
/// Define the threadblock-scoped matrix multiply-accumulate
using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
/// Define the epilogue for the 2nd Gemm
using Epilogue = typename cutlass::epilogue::threadblock::
DefaultInterleavedEpilogueTensorOp<
ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
64 / sizeof_bits<ElementC>::value, InterleavedK,
IsBetaZero>::Epilogue;
/// Define the kernel-level GEMM operator.
using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
};
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
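////////////////////////////////////////////////////////////////////////////////
//
// Illustrative note (not part of the original header): kPartitionsK1 above counts how many
// warps split the K dimension of the second GEMM's threadblock tile, and it sizes the epilogue
// accordingly. For example, ThreadblockShape1 = GemmShape<64, 128, 32> with
// WarpShape1 = GemmShape<32, 128, 32> gives kPartitionsK1 = 32 / 32 = 1, i.e. no warp-level
// split-K reduction is needed in the fused kernel's epilogue.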

View File

@@ -0,0 +1,230 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape0_,
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape1_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy0_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy1_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class B2bMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape0 = Shape0_;
using Shape1 = Shape1_;
///< Policy describing tuning details
using Policy0 = Policy0_;
using Policy1 = Policy1_;
//
// Dependent types
//
/// Warp-level Mma
using Operator0 = typename Policy0::Operator;
using Operator1 = typename Policy1::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm0 = typename Policy0::Operator::Shape;
using WarpGemm1 = typename Policy1::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount0 = GemmShape<Shape0::kM / WarpGemm0::kM,
Shape0::kN / WarpGemm0::kN,
Shape0::kK / WarpGemm0::kK>;
using WarpCount1 = GemmShape<Shape1::kM / WarpGemm1::kM,
Shape1::kN / WarpGemm1::kN,
Shape1::kK / WarpGemm1::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations0 =
(WarpGemm0::kK / Operator0::Policy::MmaShape::kK);
static int const kWarpGemmIterations1 =
(WarpGemm1::kK / Operator1::Policy::MmaShape::kK);
/// Number of stages
static int const kStages = Stages;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
template<
typename Shape_,
typename Policy_
>
class SharedStorage {
public:
//
// Type definitions
//
using Shape = Shape_;
using Policy = Policy_;
using Operator = typename Policy::Operator;
/// Tensor reference to the A operand
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
/// Shape of the A matrix operand in shared memory
using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages +
Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB =
MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA() {
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref() {
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
};
using SharedStorage0 = SharedStorage<Shape0, Policy0>;
using SharedStorage1 = SharedStorage<Shape1, Policy1>;
union B2bMmaSharedStorage {
SharedStorage0 sharedStorage0;
SharedStorage1 sharedStorage1;
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A0 operand from shared memory
typename Operator0::IteratorA warp_tile_iterator_A0_;
/// Iterator to load a warp-scoped tile of B0 operand from shared memory
typename Operator0::IteratorB warp_tile_iterator_B0_;
/// Iterator to load a warp-scoped tile of B1 operand from shared memory
typename Operator1::IteratorB warp_tile_iterator_B1_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
B2bMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
B2bMmaSharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),
warp_tile_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), lane_idx),
warp_tile_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), lane_idx) {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
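//
// Illustrative sizing of the shared storage above (not part of the original file): ignoring
// SmemPaddingA/B, kStages = 2 and Shape0 = GemmShape<64, 64, 32> give
// ShapeA = MatrixShape<64, 64> and ShapeB = MatrixShape<64, 64>; with half-precision operands
// each buffer holds 64 * 64 * 2 bytes = 8 KB, so SharedStorage0 occupies 16 KB of the
// B2bMmaSharedStorage union.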

View File

@@ -0,0 +1,509 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
#include "threadblock/b2b_mma_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////////////////////
template<int a>
struct chk_val {
static_assert(a==0, "check value");
};
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape0_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorA0_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA0_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB0_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB0_,
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape1_,
/// Iterates over the intermediate accumulator tile
// (concept::MmaTensorOpFragmentIterator)
typename FragmentIteratorA1_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
typename IteratorB1_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB1_,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
typename OutputOp_,
/// Policy describing tuning details (concept: MmaPipelinedPolicy)
typename Policy0_,
/// Policy describing tuning details (concept: MmaPipelinedPolicy)
typename Policy1_,
/// Transformation applied to A0 operand
typename TransformA0_ = NumericArrayConverter<
typename SmemIteratorA0_::Element,
typename IteratorA0_::Element,
IteratorA0_::Fragment::kElements>,
///
/// Transformation applied to B0 operand
typename TransformB0_ = NumericArrayConverter<
typename SmemIteratorB0_::Element,
typename IteratorB0_::Element,
IteratorB0_::Fragment::kElements>,
///
/// Transformation applied to B1 operand
typename TransformB1_ = NumericArrayConverter<
typename SmemIteratorB1_::Element,
typename IteratorB1_::Element,
IteratorB1_::Fragment::kElements>,
/// Used for partial specialization
typename Enable = bool
>
class B2bMmaPipelined : public B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2> {
public:
///< Base class
using Base = B2bMmaBase<Shape0_, Shape1_, Policy0_, Policy1_, 2>;
using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
using Policy0 = Policy0_; ///< Policy describing tuning details
using SmemIteratorA0 = SmemIteratorA0_;
using SmemIteratorB0 = SmemIteratorB0_;
using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile
using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
using Policy1 = Policy1_; ///< Policy describing tuning details
using SmemIteratorB1 = SmemIteratorB1_;
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
using TransformA0 = TransformA0_;
using TransformB0 = TransformB0_;
using TransformB1 = TransformB1_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA0 = typename IteratorA0::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB0 = typename IteratorB0::Fragment;
/// Fragment of accumulator tile
using FragmentC0 = typename Policy0::Operator::FragmentC;
/// Warp-level Mma
using Operator0 = typename Policy0::Operator;
/// Fragment of operand B loaded from global memory
using FragmentB1 = typename IteratorB1::Fragment;
/// Fragment of accumulator tile
using FragmentC1 = typename Policy1::Operator::FragmentC;
/// Warp-level Mma
using Operator1 = typename Policy1::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy0::Operator::ArchTag;
/// Complex transform on A0 operand
static ComplexTransform const kTransformA0 = Operator0::kTransformA;
/// Complex transform on B0 operand
static ComplexTransform const kTransformB0 = Operator0::kTransformB;
/// Complex transform on B1 operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
// statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
private:
using WarpFragmentA0 = typename Operator0::FragmentA;
using WarpFragmentB0 = typename Operator0::FragmentB;
/// Warp Fragment of operand A1 loaded from accumulator tile
using WarpFragmentA1 = typename FragmentIteratorA1::Fragment;
using WarpFragmentB1 = typename Operator1::FragmentB;
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA0 smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B0 operand to shared memory
SmemIteratorB0 smem_iterator_B0_;
/// Iterator to write threadblock-scoped tile of B1 operand to shared memory
SmemIteratorB1 smem_iterator_B1_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
B2bMmaPipelined(
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
//These should stay the same across different GEMM layers
int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;
//These may change across different GEMM layers
int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k;
int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0});
this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n});
this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n});
}
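// Worked example of the decomposition above (illustrative): with
// Base::WarpCount0 = GemmShape<2, 2, 1> there are four warps per threadblock.
// warp_idx = 3 gives warp_idx_mn = 3, warp_idx_k = 0, warp_idx_m = 1, warp_idx_n = 1,
// so that warp owns the lower-right warp tile of both GEMMs and both k-group offsets
// (tile_offset_k_0, tile_offset_k_1) are zero.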
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
int gemm_k_iterations_0, ///< number of iterations of the mainloop
FragmentC1 &accum, ///< destination accumulator tile
IteratorA0 iterator_A, ///< iterator over A operand in global memory
IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
FragmentC0 const &src_accum, ///< source accumulator tile
OutputOp output_op_0, ///< epilogue operation after 1st Gemm
TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
//
// Prologue
//
// Perform accumulation in the 'd' output operand
FragmentC0 accum0 = src_accum;
FragmentA0 tb_frag_A;
FragmentB0 tb_frag_B0;
tb_frag_A.clear();
tb_frag_B0.clear();
// The last kblock is loaded in the prologue
iterator_A.load(tb_frag_A);
iterator_B0.load(tb_frag_B0);
++iterator_A;
++iterator_B0;
this->smem_iterator_A_.store(tb_frag_A);
this->smem_iterator_B0_.store(tb_frag_B0);
++this->smem_iterator_A_;
++this->smem_iterator_B0_;
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA0 warp_frag_A0[2];
WarpFragmentB0 warp_frag_B0[2];
this->warp_tile_iterator_A0_.set_kgroup_index(0);
this->warp_tile_iterator_B0_.set_kgroup_index(0);
this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);
this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);
++this->warp_tile_iterator_A0_;
++this->warp_tile_iterator_B0_;
Operator0 warp_mma0;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
if (gemm_k_iterations_0 <= 1) {
iterator_A.clear_mask();
iterator_B0.clear_mask();
}
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
iterator_A.load(tb_frag_A);
//
// Mainloop
//
// Note: The main loop does not support Base::WarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
// as the case may be.
if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
// Write fragments to shared memory
this->smem_iterator_A_.store(tb_frag_A);
this->smem_iterator_B0_.store(tb_frag_B0);
__syncthreads();
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
iterator_A.load(tb_frag_A);
++this->smem_iterator_B0_;
++this->smem_iterator_A_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
}
else {
this->warp_tile_iterator_A0_.add_tile_offset(
{0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0});
this->warp_tile_iterator_B0_.add_tile_offset(
{-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0,
0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A0_;
++this->warp_tile_iterator_B0_;
if (warp_mma_k == 0) {
iterator_B0.load(tb_frag_B0);
++iterator_A;
++iterator_B0;
// Avoid reading out of bounds if this was the last loop iteration
if (gemm_k_iterations_0 <= 2) {
iterator_A.clear_mask();
iterator_B0.clear_mask();
}
}
warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);
}
}
//2nd Gemm
/// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
//
// Prologue
//
FragmentB1 tb_frag_B1;
tb_frag_B1.clear();
// The last kblock is loaded in the prologue
iterator_B1.load(tb_frag_B1);
++iterator_B1;
this->smem_iterator_B1_.store(tb_frag_B1);
++this->smem_iterator_B1_;
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA1 warp_frag_A1[2];
WarpFragmentB1 warp_frag_B1[2];
//warp_tile_iterator_A1_.set_kgroup_index(0);
this->warp_tile_iterator_B1_.set_kgroup_index(0);
warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0);
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
++warp_tile_iterator_A1_;
++this->warp_tile_iterator_B1_;
Operator1 warp_mma1;
smem_write_stage_idx = 1;
int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
// Avoid reading out of bounds
if (gemm_k_iterations_1 <= 1) {
iterator_B1.clear_mask();
}
//
// Mainloop
//
// Note: The main loop does not support Base::WarpGemmIterations == 2.
CUTLASS_PRAGMA_UNROLL
for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
// as the case may be.
if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
// Write fragments to shared memory
this->smem_iterator_B1_.store(tb_frag_B1);
__syncthreads();
++smem_iterator_B1_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1) {
smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
}
else {
this->warp_tile_iterator_B1_.add_tile_offset(
{-Base::kStages * Policy1::kPartitionsK *
Base::kWarpGemmIterations1,
0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
++warp_tile_iterator_A1_;
++this->warp_tile_iterator_B1_;
if (warp_mma_k == 0) {
iterator_B1.load(tb_frag_B1);
++iterator_B1;
// Avoid reading out of bounds if this was the last loop iteration
if (gemm_k_iterations_1 <= 2) {
iterator_B1.clear_mask();
}
}
warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
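/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Data-flow summary (a reading aid, not part of the original file): the first GEMM keeps its
// accumulator tile accum0 entirely in registers. FragmentIteratorA1 then walks that register
// tile, applies output_op_0 (for example a scale-plus-ReLU epilogue) on the fly, and feeds the
// result to warp_mma1 as the A operand of the second GEMM, so the intermediate matrix never
// touches global memory. Only the B1 operand is staged through shared memory, exactly as B0
// was for the first GEMM.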

View File

@@ -0,0 +1,289 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Default configuration constructing the threadblock-scoped back-to-back fused matrix multiply (DefaultB2bMma) from data types, tile shapes, and architecture tags.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
#include "threadblock/b2b_mma_pipelined.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false>
struct DefaultB2bMma;
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Epilogue output operator
typename EpilogueOutputOp>
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
kAlignmentB, ElementAccumulator, layout::RowMajor,
OperatorClass, ArchTag,
ThreadblockShape0, ThreadblockShape1,
WarpShape0, WarpShape1,
InstructionShape, 2, Operator, EpilogueOutputOp, false> {
// Define the MmaCore components
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
OperatorClass, 2, Operator>;
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
OperatorClass, 2, Operator>;
// Define iterators over tiles from the A operand
using IteratorA0 =
cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>,
ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>;
// Define iterators over tiles from the B operand
using IteratorB0 =
cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>,
ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>;
// Use fragment iterator for A operand
using AccumulatorLayout = cutlass::layout::ColumnMajor;
using FragmentIteratorA1 =
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
MmaCore1::Shape::kK, //kBlocksColumn
ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp, true>;
// Define iterators over tiles from the B operand
using IteratorB1 =
cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
IteratorB0, typename MmaCore0::SmemIteratorB,
typename MmaCore1::Shape, FragmentIteratorA1,
IteratorB1, typename MmaCore1::SmemIteratorB,
ElementAccumulator, layout::RowMajor,
EpilogueOutputOp,
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for column-major-interleaved output
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape0,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape1,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape0,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape1,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Number of Interleaved K
int InterleavedK>
struct DefaultB2bMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
kAlignmentB, ElementAccumulator,
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, ArchTag,
ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
InstructionShape, 2, Operator, EpilogueOutputOp, true> {
// Define the MmaCore components
using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
ElementB, LayoutB, ElementAccumulator,
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
true>;
using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
ElementB, LayoutB, ElementAccumulator,
layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
true>;
static_assert(kAlignmentA == 128 / sizeof_bits<ElementA>::value,
"Alignment must match thread data map's vector length");
static_assert(kAlignmentB == 128 / sizeof_bits<ElementB>::value,
"Alignment must match thread data map's vector length");
// Define iterators over tiles from the A operand
using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore0::Shape::kM, MmaCore0::Shape::kK>, ElementA,
LayoutA, 1, typename MmaCore0::IteratorThreadMapA>;
// Define iterators over tiles from the B operand
using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore0::Shape::kK, MmaCore0::Shape::kN>, ElementB,
LayoutB, 0, typename MmaCore0::IteratorThreadMapB>;
// Use fragment iterator for A operand
using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true
using FragmentIteratorA1 =
cutlass::gemm::warp::MmaTensorOpFragmentIterator<
cutlass::MatrixShape<MmaCore1::WarpShape::kM, MmaCore1::InstructionShape::kK>, //warp shape
cutlass::MatrixShape<MmaCore0::WarpShape::kM, MmaCore0::WarpShape::kN>, //accumulator shape
MmaCore1::Shape::kK, //kBlocksColumn
ElementAccumulator, ElementA, AccumulatorLayout,
InstructionShape, EpilogueOutputOp, true /*only handle beta=0 for 1st Gemm epilogue*/>;
// Define iterators over tiles from the B operand
using IteratorB1 =
cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore1::Shape::kK, MmaCore1::Shape::kN>,
ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
IteratorB0, typename MmaCore0::SmemIteratorB,
typename MmaCore1::Shape, FragmentIteratorA1,
IteratorB1, typename MmaCore1::SmemIteratorB,
ElementAccumulator, layout::ColumnMajorInterleaved<InterleavedK>,
EpilogueOutputOp,
typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@ -60,6 +60,8 @@ foreach(EXAMPLE
08_turing_tensorop_gemm
10_planar_complex
11_planar_complex_array
12_gemm_bias_relu
13_fused_two_gemms
)
add_subdirectory(${EXAMPLE})

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -52,6 +52,10 @@ struct Sm72 {
struct Sm75 {
static int const kMinComputeCapability = 75;
};
struct Sm80 {
static int const kMinComputeCapability = 80;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch

View File

@ -0,0 +1,60 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Directives related to cache operations
*/
#pragma once
#include "cutlass/cutlass.h"
namespace cutlass {
namespace arch {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Controls PTX cache operations
struct CacheOperation {
enum Kind {
/// Cache at all levels - accessed again
Always,
/// Cache at global level
Global,
/// Streaming - likely to be accessed once
Streaming,
/// Indicates the line will not be used again
LastUse,
/// Don't cache, and fetch again
Volatile,
/// Write back at all coherent levels
WriteBack,
/// Write through to system memory
WriteThrough
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
} // namespace cutlass
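For orientation, a small host-side helper (hypothetical, not part of this commit or the CUTLASS API) showing the PTX cache qualifiers that each CacheOperation::Kind is named after; the memory operations introduced later in this commit emit a subset of these qualifiers.

// Illustrative mapping only; function name is an assumption for this sketch.
inline char const *ptx_cache_qualifier(cutlass::arch::CacheOperation::Kind kind) {
  switch (kind) {
    case cutlass::arch::CacheOperation::Always:       return ".ca";  // cache at all levels
    case cutlass::arch::CacheOperation::Global:       return ".cg";  // cache in L2, bypass L1
    case cutlass::arch::CacheOperation::Streaming:    return ".cs";  // evict-first streaming
    case cutlass::arch::CacheOperation::LastUse:      return ".lu";  // last use of the line
    case cutlass::arch::CacheOperation::Volatile:     return ".cv";  // don't cache, fetch again
    case cutlass::arch::CacheOperation::WriteBack:    return ".wb";  // write back at all coherent levels
    case cutlass::arch::CacheOperation::WriteThrough: return ".wt";  // write through to system memory
  }
  return "";
}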

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -28,13 +28,271 @@
#pragma once
#include "cutlass/cutlass.h"
namespace cutlass {
namespace arch {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Fragment type to store loaded data
typename AccessType,
/// The bytes of loading
int LoadBytes
>
struct global_load;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename AccessType
>
struct global_load<AccessType,
32
> {
CUTLASS_DEVICE
global_load(AccessType &D, void const *ptr, bool pred_guard) {
uint4 *data = reinterpret_cast<uint4 *>(&D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %9, 0;\n"
" mov.b32 %0, %10;\n"
" mov.b32 %1, %11;\n"
" mov.b32 %2, %12;\n"
" mov.b32 %3, %13;\n"
" mov.b32 %4, %14;\n"
" mov.b32 %5, %15;\n"
" mov.b32 %6, %16;\n"
" mov.b32 %7, %17;\n"
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
" @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n"
"}\n"
: "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w),
"=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w)
: "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y),
"r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y),
"r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16));
}
};
template <typename AccessType
>
struct global_load<AccessType,
16
> {
CUTLASS_DEVICE
global_load(AccessType &D, void const *ptr, bool pred_guard) {
uint4 &data = reinterpret_cast<uint4 &>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" mov.b32 %0, %6;\n"
" mov.b32 %1, %7;\n"
" mov.b32 %2, %8;\n"
" mov.b32 %3, %9;\n"
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
"}\n"
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
: "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
}
};
template <typename AccessType
>
struct global_load<AccessType,
8
> {
CUTLASS_DEVICE
global_load(AccessType &D, void const *ptr, bool pred_guard) {
uint2 &data = reinterpret_cast<uint2 &>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %3, 0;\n"
" mov.b32 %0, %4;\n"
" mov.b32 %1, %5;\n"
" @p ld.global.v2.u32 {%0, %1}, [%2];\n"
"}\n"
: "=r"(data.x), "=r"(data.y)
: "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y));
}
};
template <typename AccessType
>
struct global_load<AccessType,
4
> {
CUTLASS_DEVICE
global_load(AccessType &D, void const *ptr, bool pred_guard) {
unsigned &data = reinterpret_cast<unsigned &>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
" mov.b32 %0, %3;\n"
" @p ld.global.u32 %0, [%1];\n"
"}\n"
: "=r"(data)
: "l"(ptr), "r"((int)pred_guard), "r"(data));
}
};
template <typename AccessType
>
struct global_load<AccessType,
2
> {
CUTLASS_DEVICE
global_load(AccessType &D, void const *ptr, bool pred_guard) {
uint16_t &data = reinterpret_cast<uint16_t &>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
" mov.b16 %0, %3;\n"
" @p ld.global.u16 %0, [%1];\n"
"}\n"
: "=h"(data)
: "l"(ptr), "r"((int)pred_guard), "h"(data));
}
};
template <typename AccessType
>
struct global_load<AccessType,
1
> {
CUTLASS_DEVICE
global_load(AccessType &D, void const *ptr, bool pred_guard) {
if (pred_guard) D = *(reinterpret_cast<AccessType const *>(ptr));
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Fragment type to store loaded data
typename AccessType,
/// The bytes of loading
int LoadBytes
>
struct global_store;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename AccessType>
struct global_store<AccessType, 32> {
CUTLASS_DEVICE
global_store(AccessType const &D, void *ptr, bool pred_guard) {
uint4 const *data = reinterpret_cast<uint4 const *>(&D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
" @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
"}\n"
:
: "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
"r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16),
"r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
}
};
template <typename AccessType>
struct global_store<AccessType, 16> {
CUTLASS_DEVICE
global_store(AccessType const &D, void *ptr, bool pred_guard) {
uint4 const &data = reinterpret_cast<uint4 const &>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
"}\n"
:
: "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w), "r"((int)pred_guard));
}
};
template <typename AccessType>
struct global_store<AccessType, 8> {
CUTLASS_DEVICE
global_store(AccessType const &D, void *ptr, bool pred_guard) {
uint2 const &data = reinterpret_cast<uint2 const &>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %3, 0;\n"
" @p st.global.v2.u32 [%0], {%1, %2};\n"
"}\n"
:
: "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard));
}
};
template <typename AccessType>
struct global_store<AccessType, 4> {
CUTLASS_DEVICE
global_store(AccessType const &D, void *ptr, bool pred_guard) {
uint32_t const &data = reinterpret_cast<uint32_t const &>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
" @p st.global.u32 [%0], %1;\n"
"}\n"
:
: "l"(ptr), "r"(data), "r"((int)pred_guard));
}
};
template <typename AccessType>
struct global_store<AccessType, 2> {
CUTLASS_DEVICE
global_store(AccessType const &D, void *ptr, bool pred_guard) {
uint16_t const &data = reinterpret_cast<uint16_t const &>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
" @p st.global.u16 [%0], %1;\n"
"}\n"
:
: "l"(ptr), "h"(data), "r"((int)pred_guard));
}
};
template <typename AccessType>
struct global_store<AccessType, 1> {
CUTLASS_DEVICE
global_store(AccessType const &D, void *ptr, bool pred_guard) {
if (pred_guard) *(reinterpret_cast<AccessType *>(ptr)) = D;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
} // namespace cutlass
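A minimal device-side sketch (hypothetical, not part of this commit) of how the predicated vector load/store wrappers above are used; the kernel name and the assumption of one 16-byte fragment per thread are illustrative only.

#include "cutlass/array.h"
#include "cutlass/arch/memory.h"   // assumed header providing global_load / global_store

// Copies one 16-byte fragment per thread with bounds predication.
__global__ void copy_guarded(cutlass::Array<float, 4> const *src,
                             cutlass::Array<float, 4> *dst,
                             int count) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  bool guard = (idx < count);

  cutlass::Array<float, 4> frag;
  frag.clear();

  // Predicated 16B vector load: frag keeps its prior contents when guard is false.
  cutlass::arch::global_load<cutlass::Array<float, 4>, 16>(frag, src + idx, guard);

  // Predicated 16B vector store: no write is issued when guard is false.
  cutlass::arch::global_store<cutlass::Array<float, 4>, 16>(frag, dst + idx, guard);
}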
@ -42,4 +300,6 @@ namespace arch {
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "memory_sm75.h"
#include "memory_sm80.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -50,20 +50,20 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
//
/////////////////////////////////////////////////////////////////////////////////////////////////
#if ! defined(CUDA_LDMATRIX_SUPPORTED)
#define CUDA_LDMATRIX_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2))
#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || (__CUDACC_VER_MAJOR__ >= 11)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
#define CUDA_LDMATRIX_ACTIVATED 1
#endif
#if ! defined(CUDA_LDMATRIX_ENABLED)
#define CUDA_LDMATRIX_ENABLED CUDA_LDMATRIX_SUPPORTED
#endif
#if CUDA_LDMATRIX_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
#define CUDA_LDMATRIX_ACTIVATED 1
#define CUDA_LDMATRIX_SUPPORTED 1
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) && (__CUDACC_VER_MAJOR__ > 10)
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED 1
#endif
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED)
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1))
#endif
@ -71,8 +71,9 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED)
#define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED
#endif
*/
#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED
#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
extern "C" {
//
// This NVVM intrinsic is subject to change in future versions of CUDA.
@ -85,19 +86,49 @@ inline __device__ void ldsm(Array<unsigned, MatrixCount> & D, void const* ptr);
/////////////////////////////////////////////////////////////////////////////////////////////////
#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED
/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to
// the previous internal intrinsics if they are available.
#if (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
//
// This NVVM intrinsic converts an address in shared memory to a plain
// unsigned integer. This is necessary to pass to shared memory instructions
// in inline PTX.
//
// In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2].
//
//__device__ size_t __cvta_generic_to_shared(void* ptr);
/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) {
return __nvvm_get_smem_pointer(const_cast<void *>(ptr));
}
return static_cast<unsigned>(__cvta_generic_to_shared(ptr));
/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
return __nvvm_get_smem_pointer(ptr);
}
#elif (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
return __nvvm_get_smem_pointer(ptr);
#elif defined(__CUDA_ARCH__)
uint32_t smem_ptr;
asm(
"{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_ptr) : "l"(ptr));
return smem_ptr;
#else
return 0;
#endif
}
/// CUTLASS helper to get SMEM pointer
inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) {
return cutlass_get_smem_pointer(const_cast<void *>(ptr));
}
/////////////////////////////////////////////////////////////////////////////////////////////////
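A brief device-side illustration (hypothetical, not from this commit) of what the helper is for: shared-memory PTX instructions take a 32-bit shared-state-space address rather than a generic pointer, so the generic address must be converted first.

#include "cutlass/arch/memory_sm75.h"   // assumed header providing cutlass_get_smem_pointer

__global__ void smem_pointer_demo() {
  __shared__ unsigned buffer[128];   // assumes blockDim.x <= 128

  // Convert a generic shared-memory address into the 32-bit form expected by inline PTX.
  unsigned smem_addr = cutlass::arch::cutlass_get_smem_pointer(&buffer[threadIdx.x]);

  unsigned value = threadIdx.x;
  asm volatile("st.shared.u32 [%0], %1;\n" : : "r"(smem_addr), "r"(value));
}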
template <>
@ -235,5 +266,6 @@ inline __device__ void ldsm<layout::ColumnMajor, 4>(
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
} // namespace cutlass

View File

@ -0,0 +1,238 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Architecture-specific operators on memory added for SM80
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/cache_operation.h"
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#define CUDA_CP_ASYNC_ACTIVATED 1
#else
#define CUDA_CP_ASYNC_ACTIVATED 0
#endif
namespace cutlass {
namespace arch {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Initiates an asynchronous copy from global memory to shared memory.
///
/// LDGSTS
///
template <
/// Size of the access in bytes
int SizeInBytes,
/// Cache operation
CacheOperation::Kind cache_op = CacheOperation::Always>
struct cp_async;
/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate
/// the entire transfer, zeros are written to SMEM if the guard predicate is false.
///
/// LDGSTS
///
template <
/// Size of the access in bytes
int SizeInBytes,
/// Cache operation
CacheOperation::Kind cache_op = CacheOperation::Always>
struct cp_async_zfill;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct cp_async<SizeInBytes, CacheOperation::Always> {
/// Copy
CUTLASS_DEVICE
cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
#if CUDA_CP_ASYNC_ACTIVATED
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred_guard),
"r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
#else
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
#endif
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct cp_async_zfill<SizeInBytes, CacheOperation::Always> {
/// Copy with zero fill
CUTLASS_DEVICE
cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
#if CUDA_CP_ASYNC_ACTIVATED
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
int src_in_bytes = (pred_guard ? SizeInBytes : 0);
asm volatile(
"cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
"l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes));
#else
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
else {
AccessType zeros;
zeros.clear();
*static_cast<AccessType *>(smem_ptr) = zeros;
}
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct cp_async<SizeInBytes, CacheOperation::Global> {
/// Copy
CUTLASS_DEVICE
cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
#if CUDA_CP_ASYNC_ACTIVATED
static_assert(SizeInBytes == 16,
"cp.async only supports CacheOperation::Global when access size is 16B.");
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred_guard),
"r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
#else
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
#endif
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct cp_async_zfill<SizeInBytes, CacheOperation::Global> {
/// Copy with zero fill
CUTLASS_DEVICE
cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
#if CUDA_CP_ASYNC_ACTIVATED
static_assert(SizeInBytes == 16,
"cp.async only supports CacheOperation::Global when access size is 16B.");
unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
int src_in_bytes = (pred_guard ? SizeInBytes : 0);
asm volatile(
"cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
"l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes));
#else
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
else {
AccessType zeros;
zeros.clear();
*static_cast<AccessType *>(smem_ptr) = zeros;
}
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
CUTLASS_DEVICE
void cp_async_fence() {
#if CUDA_CP_ASYNC_ACTIVATED
asm volatile("cp.async.commit_group;\n" ::);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Blocks until all but <N> previous cp.async.commit_group operations have committed.
template <int N>
CUTLASS_DEVICE void cp_async_wait() {
#if CUDA_CP_ASYNC_ACTIVATED
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}
/// Blocks until all previous cp.async.commit_group operations have committed.
template <>
CUTLASS_DEVICE void cp_async_wait<0>() {
#if CUDA_CP_ASYNC_ACTIVATED
asm volatile("cp.async.wait_all;\n" ::);
#endif
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
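To make the intended usage concrete, a simplified prefetch loop (hypothetical sketch, not part of this commit) that stages one 16-byte fragment per thread per tile; real multistage kernels overlap these copies with math on previously staged tiles, and predicate the tail.

#include "cutlass/arch/memory_sm80.h"   // assumed header providing cp_async and friends

// Assumes blockDim.x == 128; tail predication is omitted for brevity.
template <int Stages>
__global__ void prefetch_demo(float4 const *global_src, int tiles) {
  __shared__ float4 smem[Stages][128];

  for (int tile = 0; tile < tiles; ++tile) {
    int stage = tile % Stages;

    // Issue an asynchronous 16B global->shared copy (synchronous fallback before SM80).
    cutlass::arch::cp_async<16, cutlass::arch::CacheOperation::Always>(
        &smem[stage][threadIdx.x],
        global_src + tile * blockDim.x + threadIdx.x,
        true);

    // Close the current group of copies.
    cutlass::arch::cp_async_fence();

    // Allow up to Stages - 1 groups to remain in flight, then synchronize the block.
    cutlass::arch::cp_async_wait<Stages - 1>();
    __syncthreads();

    // ... consume smem[stage] here ...
  }
}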

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -51,11 +51,26 @@ struct OpMultiplyAddSaturate;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input is converted to a narrower type (BF16)
struct OpMultiplyAddFastBF16;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input is converted to a narrower type (F16)
struct OpMultiplyAddFastF16;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the complex multiply-add operation
struct OpMultiplyAddComplex;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the gaussian complex multiply-add operation
struct OpMultiplyAddGaussianComplex;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the inner product is defined by (XOR, POPC)
struct OpXorPopc;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -120,8 +120,7 @@ struct Wmma<
////////////////////////////////////////////////////////////////////////////////
//
// WMMA template structure defines nvcuda::wmma::fragments and static assert for
// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1)
// (nvcuda::wmma targetting SASS instruction BMMA)
// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1).
//
////////////////////////////////////////////////////////////////////////////////
template <

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -167,7 +167,7 @@ public:
class const_iterator {
/// Pointer to object
T *ptr_;
const T *ptr_;
public:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

461
include/cutlass/bfloat16.h Normal file
View File

@ -0,0 +1,461 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Defines a proxy class for storing non-standard 16-bit floating point values with
8 bits of exponent and 7 bits of mantissa.
*/
#pragma once
#if !defined(__CUDACC_RTC__)
#include <cmath>
#include <limits>
#include <cstdint>
#endif
#include "cutlass/cutlass.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Floating-point type with 8 bits of exponent and 7 bits of mantissa.
struct alignas(2) bfloat16_t {
//
// Data members
//
/// Storage type
uint16_t storage;
//
// Methods
//
/// Constructs from an unsigned short
CUTLASS_HOST_DEVICE
static bfloat16_t bitcast(uint16_t x) {
bfloat16_t h;
h.storage = x;
return h;
}
/// Default constructor
CUTLASS_HOST_DEVICE
bfloat16_t() { }
/// Floating-point conversion - round toward nearest
CUTLASS_HOST_DEVICE
explicit bfloat16_t(float x) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x));
#else
uint32_t bits = reinterpret_cast<uint32_t &>(x);
if ((bits & 0x7f800000) != 0x7f800000) {
bool mantissa_bit = ((bits & (1 << 16)) != 0);
bool round_bit = ((bits & (1 << 15)) != 0);
bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0);
if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) {
bits += uint32_t(1 << 16);
}
}
else if (bits & ~0xff800000) {
bits = 0x7fffffff;
}
storage = uint16_t((bits >> 16) & 0xffff);
#endif
}
/// Floating-point conversion - round toward nearest
CUTLASS_HOST_DEVICE
explicit bfloat16_t(double x): bfloat16_t(float(x)) {
}
/// Integer conversion - round toward nearest
CUTLASS_HOST_DEVICE
explicit bfloat16_t(int x) {
float flt = static_cast<float>(x);
storage = uint16_t(reinterpret_cast<uint32_t const &>(flt) >> 16);
}
/// Converts to float
CUTLASS_HOST_DEVICE
operator float() const {
unsigned bits = (unsigned(storage) << 16);
return reinterpret_cast<float const &>(bits);
}
/// Converts to float
CUTLASS_HOST_DEVICE
operator double() const {
return double(float(*this));
}
/// Converts to int
CUTLASS_HOST_DEVICE
explicit operator int() const {
return int(float(*this));
}
/// Casts to bool
CUTLASS_HOST_DEVICE
operator bool() const {
return (float(*this) != 0.0f);
}
/// Obtains raw bits
CUTLASS_HOST_DEVICE
uint16_t raw() const {
return storage;
}
/// Returns the sign bit
CUTLASS_HOST_DEVICE
bool signbit() const {
return ((raw() & 0x8000) != 0);
}
/// Returns the biased exponent
CUTLASS_HOST_DEVICE
int exponent_biased() const {
return int((raw() >> 7) & 0x0ff);
}
/// Returns the unbiased exponent
CUTLASS_HOST_DEVICE
int exponent() const {
return exponent_biased() - 127;
}
/// Returns the mantissa
CUTLASS_HOST_DEVICE
int mantissa() const {
return int(raw() & 0x7f);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
CUTLASS_HOST_DEVICE
bool signbit(cutlass::bfloat16_t const& h) {
return h.signbit();
}
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) {
return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fff);
}
CUTLASS_HOST_DEVICE
bool isnan(cutlass::bfloat16_t const& h) {
return (h.exponent_biased() == 0x0ff) && h.mantissa();
}
CUTLASS_HOST_DEVICE
bool isfinite(cutlass::bfloat16_t const& h) {
return (h.exponent_biased() != 0x0ff);
}
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t nan_bf16(const char*) {
// NVIDIA canonical NaN
return cutlass::bfloat16_t::bitcast(0x7fff);
}
CUTLASS_HOST_DEVICE
bool isinf(cutlass::bfloat16_t const& h) {
return (h.exponent_biased() == 0x0ff) && !h.mantissa();
}
CUTLASS_HOST_DEVICE
bool isnormal(cutlass::bfloat16_t const& h) {
return h.exponent_biased() && h.exponent_biased() != 0x0ff;
}
CUTLASS_HOST_DEVICE
int fpclassify(cutlass::bfloat16_t const& h) {
int exp = h.exponent_biased();
int mantissa = h.mantissa();
if (exp == 0x0ff) {
if (mantissa) {
return FP_NAN;
}
else {
return FP_INFINITE;
}
}
else if (!exp) {
if (mantissa) {
return FP_SUBNORMAL;
}
else {
return FP_ZERO;
}
}
return FP_NORMAL;
}
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) {
#if defined(__CUDACC_RTC__)
return cutlass::bfloat16_t(sqrtf(float(h)));
#else
return cutlass::bfloat16_t(std::sqrt(float(h)));
#endif
}
CUTLASS_HOST_DEVICE
bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) {
uint16_t a_mag = (reinterpret_cast<uint16_t const &>(a) & 0x7fff);
uint16_t b_sign = (reinterpret_cast<uint16_t const &>(b) & 0x8000);
uint16_t result = (a_mag | b_sign);
return reinterpret_cast<bfloat16_t const &>(result);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Standard Library operations and definitions
//
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace std {
#if !defined(__CUDACC_RTC__)
/// Numeric limits
template <>
struct numeric_limits<cutlass::bfloat16_t> {
static bool const is_specialized = true;
static bool const is_signed = true;
static bool const is_integer = false;
static bool const is_exact = false;
static bool const has_infinity = true;
static bool const has_quiet_NaN = true;
static bool const has_signaling_NaN = false;
static std::float_denorm_style const has_denorm = std::denorm_present;
static bool const has_denorm_loss = true;
static std::float_round_style const round_style = std::round_to_nearest;
static bool const is_iec559 = false;
static bool const is_bounded = true;
static bool const is_modulo = false;
static int const digits = 7;
/// Least positive value
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); }
/// Minimum finite value
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); }
/// Maximum finite value
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); }
/// Returns the machine epsilon
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); }
/// Returns the largest possible rounding error
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); }
/// Returns positive infinity
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); }
/// Returns a quiet NaN
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }
/// Returns a signaling NaN
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); }
/// Returns the smallest positive subnormal value
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); }
};
#endif
} // namespace std
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Arithmetic operators
//
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
CUTLASS_HOST_DEVICE
bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) {
return float(lhs) == float(rhs);
}
CUTLASS_HOST_DEVICE
bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
return float(lhs) != float(rhs);
}
CUTLASS_HOST_DEVICE
bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) {
return float(lhs) < float(rhs);
}
CUTLASS_HOST_DEVICE
bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
return float(lhs) <= float(rhs);
}
CUTLASS_HOST_DEVICE
bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) {
return float(lhs) > float(rhs);
}
CUTLASS_HOST_DEVICE
bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) {
return float(lhs) >= float(rhs);
}
CUTLASS_HOST_DEVICE
bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) {
return bfloat16_t(float(lhs) + float(rhs));
}
CUTLASS_HOST_DEVICE
bfloat16_t operator-(bfloat16_t const& lhs) {
return bfloat16_t(-float(lhs));
}
CUTLASS_HOST_DEVICE
bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) {
return bfloat16_t(float(lhs) - float(rhs));
}
CUTLASS_HOST_DEVICE
bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) {
return bfloat16_t(float(lhs) * float(rhs));
}
CUTLASS_HOST_DEVICE
bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) {
return bfloat16_t(float(lhs) / float(rhs));
}
CUTLASS_HOST_DEVICE
bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) {
lhs = bfloat16_t(float(lhs) + float(rhs));
return lhs;
}
CUTLASS_HOST_DEVICE
bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) {
lhs = bfloat16_t(float(lhs) - float(rhs));
return lhs;
}
CUTLASS_HOST_DEVICE
bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) {
lhs = bfloat16_t(float(lhs) * float(rhs));
return lhs;
}
CUTLASS_HOST_DEVICE
bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) {
lhs = bfloat16_t(float(lhs) / float(rhs));
return lhs;
}
CUTLASS_HOST_DEVICE
bfloat16_t& operator++(bfloat16_t & lhs) {
float tmp(lhs);
++tmp;
lhs = bfloat16_t(tmp);
return lhs;
}
CUTLASS_HOST_DEVICE
bfloat16_t& operator--(bfloat16_t & lhs) {
float tmp(lhs);
--tmp;
lhs = bfloat16_t(tmp);
return lhs;
}
CUTLASS_HOST_DEVICE
bfloat16_t operator++(bfloat16_t & lhs, int) {
bfloat16_t ret(lhs);
float tmp(lhs);
tmp++;
lhs = bfloat16_t(tmp);
return ret;
}
CUTLASS_HOST_DEVICE
bfloat16_t operator--(bfloat16_t & lhs, int) {
bfloat16_t ret(lhs);
float tmp(lhs);
tmp--;
lhs = bfloat16_t(tmp);
return ret;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// User-defined literals
//
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t operator "" _bf16(long double x) {
return cutlass::bfloat16_t(float(x));
}
CUTLASS_HOST_DEVICE
cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) {
return cutlass::bfloat16_t(int(x));
}
/////////////////////////////////////////////////////////////////////////////////////////////////
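A short host-side sketch (hypothetical, not part of this commit) exercising the conversions, arithmetic operators, and user-defined literal defined above.

#include <iostream>
#include "cutlass/bfloat16.h"

int main() {
  cutlass::bfloat16_t a(1.5f);
  cutlass::bfloat16_t b = 0.25_bf16;                         // user-defined literal from above

  cutlass::bfloat16_t c = a * b + cutlass::bfloat16_t(2);    // arithmetic round-trips through float

  std::cout << float(c) << std::endl;                        // prints 2.375 (exactly representable in bf16)
  std::cout << c.exponent() << " " << c.mantissa() << std::endl;
  return 0;
}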

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -35,6 +35,9 @@
#include "cutlass/half.h"
#include "cutlass/real.h"
#include "cutlass/bfloat16.h"
#include "cutlass/tfloat32.h"
#if !defined(__CUDACC_RTC__)
#include <iosfwd>
#endif
@ -370,6 +373,15 @@ template <typename T>
CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const &z) {
return complex<T>(real(z), -imag(z));
}
/// Identity transform for non-complex types
template <typename T>
CUTLASS_HOST_DEVICE T conj(T const &z) {
static_assert( !std::is_same<T, cuComplex>::value &&
!std::is_same<T, cuDoubleComplex>::value &&
!std::is_same<T, cutlass::complex<double>>::value &&
!std::is_same<T, cutlass::complex<float>>::value, "May not be a complex data type");
return z;
}
/// Projects the complex number z onto the Riemann sphere
template <typename T>
@ -429,6 +441,7 @@ template <typename T>
struct RealType< complex<T> > {
using Type = T;
CUTLASS_HOST_DEVICE
static complex<T> from_real(double x) {
return complex<T>(static_cast<T>(x));
}
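A tiny illustration (hypothetical) of why the identity overload of conj() added above is useful: generic code such as complex-valued GEMM reference paths can call conj() uniformly without branching on whether the element type is complex.

#include "cutlass/complex.h"

template <typename T>
T conj_or_passthrough(T const &x) {
  return cutlass::conj(x);   // complex types are conjugated, real types pass through unchanged
}

// conj_or_passthrough(cutlass::complex<float>(1.0f, -2.0f)) yields (1.0f, 2.0f)
// conj_or_passthrough(3.0f) yields 3.0f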

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -360,6 +360,29 @@ public:
namespace cutlass {
/// Scalar multiplication
template <int Rank, typename Index>
CUTLASS_HOST_DEVICE
Coord<Rank, Index> operator*(Index s, Coord<Rank, Index> coord) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank; ++i) {
coord[i] *= s;
}
return coord;
}
/// Scalar multiplication
template <int Rank, typename Index>
CUTLASS_HOST_DEVICE
Coord<Rank, Index> operator*(Coord<Rank, Index> coord, Index s) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank; ++i) {
coord[i] *= s;
}
return coord;
}
/// Scalar division
template <int Rank, typename Index>
CUTLASS_HOST_DEVICE
@ -419,3 +442,4 @@ Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
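A quick illustration (hypothetical) of the new scalar-multiplication overloads added above: both operand orders scale every component of the coordinate.

#include "cutlass/coord.h"

CUTLASS_HOST_DEVICE
void coord_scaling_demo() {
  cutlass::Coord<3> extent = cutlass::make_Coord(2, 3, 4);

  cutlass::Coord<3> doubled = 2 * extent;   // {4, 6, 8}
  cutlass::Coord<3> tripled = extent * 3;   // {6, 9, 12}
  (void)doubled; (void)tripled;
}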

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -33,9 +33,14 @@
#include "cutlass/coord.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/gemm/gemm.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
// stream operators for cutlass namespace //
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Rank>
@ -47,8 +52,6 @@ std::ostream& operator<<(std::ostream& out, Coord<Rank> const& coord) {
return out;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
inline
std::istream & operator>>(std::istream &stream, half_t &x) {
float tmp;
@ -62,6 +65,16 @@ std::ostream & operator<<(std::ostream &out, half_t const &x) {
return out << float(x);
}
inline
std::ostream & operator<<(std::ostream &out, bfloat16_t const &x) {
return out << float(x);
}
inline
std::ostream & operator<<(std::ostream &out, tfloat32_t const &x) {
return out << float(x);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to enable formatted printing of CUTLASS scalar types to an ostream
@ -98,7 +111,54 @@ inline std::ostream &operator<<(std::ostream &out, ScalarIO<uint8_t> const &scal
return out << unsigned(scalar.value);
}
/// Default printing to ostream for MatrixShape
template <int Row, int Column>
inline
std::ostream & operator<<(std::ostream &out, cutlass::MatrixShape<Row, Column> const &matrix_shape) {
out << "cutlass::MatrixShape::(kRow, kColumn) {"
<< cutlass::MatrixShape<Row,Column>::kRow <<","
<< cutlass::MatrixShape<Row,Column>::kColumn <<"}";
return out;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// stream operators for cutlass::gemm namespace //
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace gemm {
/// Default printing to ostream for GemmShape
template <int M, int N, int K>
inline
std::ostream & operator<<(std::ostream &out, cutlass::gemm::GemmShape<M,N,K> const &gemm_shape) {
out << "cutlass::GemmShape::(kM, kN, kK) {"
<< cutlass::gemm::GemmShape<M,N,K>::kM <<","
<< cutlass::gemm::GemmShape<M,N,K>::kN <<","
<< cutlass::gemm::GemmShape<M,N,K>::kK << "}";
return out;
}
} //namespace gemm
///////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////////////
// stream operators for cutlass::layout namespace //
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace layout {
/// Default printing to ostream for PitchLinearShape
template < int Contiguous, int Strided>
inline
std::ostream & operator<<(std::ostream &out, cutlass::layout::PitchLinearShape<Contiguous, Strided> const &pitch_linear_shape) {
out << "cutlass::layout::PitchLinearShape::(kContiguous, kStrided) {"
<< cutlass::layout::PitchLinearShape<Contiguous,Strided>::kContiguous <<","
<< cutlass::layout::PitchLinearShape<Contiguous,Strided>::kStrided <<"}";
return out;
}
} //namespace layout
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
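A small host-side sketch (hypothetical) of the new shape stream operators in action.

#include <iostream>
#include "cutlass/core_io.h"

int main() {
  std::cout << cutlass::gemm::GemmShape<128, 128, 32>() << std::endl;
  std::cout << cutlass::MatrixShape<64, 64>() << std::endl;
  std::cout << cutlass::layout::PitchLinearShape<128, 8>() << std::endl;
  return 0;
}
// Prints, for example: cutlass::GemmShape::(kM, kN, kK) {128,128,32}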

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -55,6 +55,8 @@ enum class Status {
kErrorNotSupported, ///< Operation is not supported on current device.
kErrorWorkspaceNull, ///< The given workspace is null when it is required to be non-null.
kErrorInternal, ///< An error within CUTLASS occurred.
kErrorArchMismatch, ///< CUTLASS runs on a device that it was not compiled for.
kErrorInsufficientDriver, ///< CUTLASS runs with a driver that is too old.
kInvalid ///< Status is unspecified.
};
@ -78,6 +80,10 @@ static char const* cutlassGetStatusString(cutlass::Status status) {
return "Error Workspace Null";
case cutlass::Status::kErrorInternal:
return "Error Internal";
case cutlass::Status::kErrorInsufficientDriver:
return "Error Insufficient Driver";
case cutlass::Status::kErrorArchMismatch:
return "Erroor Architecture Mismatch";
case cutlass::Status::kInvalid: break;
}
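A hypothetical host-side sketch of how callers are expected to consume the new status codes; the surrounding operation that produces the Status is assumed.

#include <iostream>
#include "cutlass/cutlass.h"

// Returns true if the status indicates the kernel cannot run on this device/driver.
bool report_if_unsupported(cutlass::Status status) {
  if (status == cutlass::Status::kErrorArchMismatch ||
      status == cutlass::Status::kErrorInsufficientDriver) {
    std::cerr << cutlass::cutlassGetStatusString(status) << std::endl;
    return true;
  }
  return false;
}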

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -0,0 +1,119 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This extends the contents of cutlass/functional.h with frequently used activation functions.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/complex.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/functional.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// ReLu operator - propagates NaNs
template <typename T>
struct ReLu {
CUTLASS_HOST_DEVICE
T operator()(T const & threshold, T const &value) const {
if (value < threshold) {
value = threshold;
}
return value;
}
};
template <typename T, int N>
struct ReLu<Array<T, N>> {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
Array<T, N> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
T value = frag[i];
if (value < threshold) {
value = threshold;
}
result[i] = value;
}
return result;
}
};
// Sigmoid operator
template <typename T>
struct Sigmoid {
CUTLASS_HOST_DEVICE
T operator()(T const &scalar) const {
return T(1) / (T(1) + exp(-scalar));
}
};
template <>
struct Sigmoid<float> {
CUTLASS_HOST_DEVICE
float operator()(float const &scalar) const {
return 1.0f / (1.0f + expf(-scalar));
}
};
template <typename T, int N>
struct Sigmoid<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> y;
Sigmoid<T> sigmoid_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < int(rhs.size()); ++i) {
y[i] = sigmoid_op(rhs[i]);
}
return y;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
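A minimal host-side sketch of how these activation functors behave (illustrative only; it assumes the file above is installed as cutlass/epilogue/thread/activation.h and that the CUTLASS include directory is on the compiler's include path):

#include "cutlass/array.h"
#include "cutlass/epilogue/thread/activation.h"   // assumed install path for the header above

int main() {
  cutlass::epilogue::thread::ReLu<cutlass::Array<float, 4>> relu;
  cutlass::epilogue::thread::Sigmoid<float> sigmoid;

  cutlass::Array<float, 4> frag;
  frag[0] = -1.0f; frag[1] = -0.25f; frag[2] = 0.25f; frag[3] = 1.0f;

  // Elementwise max against the threshold (0 here): {-1, -0.25, 0.25, 1} -> {0, 0, 0.25, 1}
  cutlass::Array<float, 4> rectified = relu(0.0f, frag);

  // Scalar specialization: 1 / (1 + expf(-0)) == 0.5
  float y = sigmoid(0.0f);

  return (rectified[0] == 0.0f && y == 0.5f) ? 0 : 1;
}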

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -101,7 +101,7 @@ public:
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source,
FragmentOutput const &source = FragmentOutput(),
ElementCompute uniform = ElementCompute(0)) const {
// Convert to destination numeric type

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -165,6 +165,28 @@ public:
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -178,6 +178,40 @@ public:
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
minimum<ComputeFragment> min_accumulator;
maximum<ComputeFragment> max_accumulator;
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
/// Clamping constant value
ElementCompute const kClamp =
ElementCompute((1U << (sizeof_bits<ElementOutput>::value - 1)) - 1);
intermediate = max_accumulator(intermediate, -kClamp - ElementCompute(1));
intermediate = min_accumulator(intermediate, kClamp);
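// For example, with ElementOutput = int8_t (sizeof_bits == 8), kClamp == 127 and the scaled
// accumulator is clamped to the representable range [-128, 127] before conversion.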
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -278,7 +312,7 @@ public:
beta_ = ElementCompute(1);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
@ -316,6 +350,37 @@ public:
return destination_converter(scaled_accumulator);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Compute linear scaling in floating point
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_accumulator;
// Float min-max
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
scaled_accumulator[i] = static_cast<int>(intermediate[i]);
}
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
return destination_converter(scaled_accumulator);
}
};
#endif // Conditional guards to enable partial specialization for packed integers
@ -410,7 +475,7 @@ class FastLinearCombinationClamp {
beta_ = ElementCompute(1);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &accumulator,
@ -453,6 +518,41 @@ class FastLinearCombinationClamp {
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
FastNumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Compute linear scaling in floating point
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
minimum<ComputeFragment> min_accumulator;
maximum<ComputeFragment> max_accumulator;
// Float min-max
intermediate = mul_accumulator(alpha_, converted_accumulator);
/// Clamping constant value
ElementCompute const kClamp =
ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));
intermediate = max_accumulator(intermediate, -kClamp);
intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));
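// e.g. for int8_t output, kClamp == 128, giving the same [-128, 127] range expressed as [-kClamp, kClamp - 1].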
// Convert to destination numeric type
FastNumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
return destination_converter(intermediate);
}
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -185,6 +185,39 @@ public:
destination_converter(intermediate.real),
destination_converter(intermediate.imag));
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator(
accumulator_converter(accumulator.real),
accumulator_converter(accumulator.imag));
// Perform binary operations
ComputeFragment intermediate;
multiplies<Array<ElementCompute, kCount> > mul_op;
multiply_add<Array<ElementCompute, kCount> > mul_add_op;
// complex multiply: I = alpha * AB
intermediate.real = mul_op(alpha_.real(), converted_accumulator.real);
intermediate.imag = mul_op(alpha_.real(), converted_accumulator.imag);
intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real);
intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag);
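// Sanity check: with alpha = (x + j y) and Accum = (r + j i), the four lines above yield
// real = x*r - y*i and imag = x*i + y*r, i.e. the standard complex product alpha * Accum.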
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return FragmentOutput(
destination_converter(intermediate.real),
destination_converter(intermediate.imag));
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -23,8 +23,7 @@
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination operations used by epilogues. Values are clamped before
converting to the output element type.
\brief Functor performing linear combination with a maximum operation used by epilogues.
*/
#pragma once
@ -34,6 +33,7 @@
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -43,8 +43,7 @@ namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements then clamps the output before
/// converting to the output element type.
/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
@ -75,10 +74,10 @@ public:
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute threshold; ///< Relu threshold
ElementCompute threshold; ///< minimum value that is output
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
ElementCompute const *threshold_ptr; ///< pointer to threshold scalar - if not null, loads from memory
//
// Methods
//
@ -87,16 +86,17 @@ public:
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
threshold(ElementCompute(0)),
threshold(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
beta_ptr(nullptr),
threshold_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta,
ElementCompute threshold = ElementCompute(0)
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
ElementCompute threshold = ElementCompute(0)
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr), threshold_ptr(nullptr) {
}
@ -104,8 +104,8 @@ public:
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr,
ElementCompute threshold = ElementCompute(0)
): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
ElementCompute const *threshold_ptr = nullptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), threshold_ptr(threshold_ptr) {
}
};
@ -128,7 +128,7 @@ public:
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
threshold_ = params.threshold;
threshold_ = (params.threshold_ptr ? *params.threshold_ptr : params.threshold);
}
/// Returns true if source is needed
@ -144,13 +144,12 @@ public:
beta_ = ElementCompute(1);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source,
ElementCompute uniform = ElementCompute(0)) const {
FragmentOutput const &source) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
@ -160,18 +159,44 @@ public:
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
maximum<ComputeFragment> max_accumulator;
ReLu<ComputeFragment> relu;
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
intermediate = max_accumulator(intermediate, threshold_);
// Compute threshold optionally
intermediate = relu(threshold_, intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
ReLu<ComputeFragment> relu;
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
// Compute threshold optionally
intermediate = relu(threshold_, intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
@ -180,24 +205,24 @@ public:
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && \
((__CUDACC_VER_MAJOR__ > 10) || \
((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)))
/// Applies a linear combination operator to an array of elements then clamps the output before
/// converting to the output element type.
/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
/// Special handling for int types
template <
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation
FloatRoundStyle Round
>
class LinearCombinationRelu<ElementOutput_, Count, int, float, Round> {
class LinearCombinationRelu <ElementOutput_, Count, int, float, Round> {
public:
using ElementOutput = ElementOutput_;
@ -217,10 +242,10 @@ public:
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute threshold; ///< Relu threshold
ElementCompute threshold; ///< minimum value that is output
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
ElementCompute const *threshold_ptr; ///< pointer to threshold scalar - if not null, loads from memory
//
// Methods
//
@ -229,16 +254,17 @@ public:
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
threshold(ElementCompute(0)),
threshold(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
beta_ptr(nullptr),
threshold_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta,
ElementCompute threshold = ElementCompute(0)
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
ElementCompute threshold = ElementCompute(0)
): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr), threshold_ptr(nullptr) {
}
@ -246,8 +272,8 @@ public:
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr,
ElementCompute threshold = ElementCompute(0)
): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
ElementCompute const *threshold_ptr = nullptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), threshold_ptr(threshold_ptr) {
}
};
@ -270,7 +296,7 @@ public:
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
threshold_ = params.threshold;
threshold_ = (params.threshold_ptr ? *params.threshold_ptr : params.threshold);
}
/// Returns true if source is needed
@ -286,13 +312,12 @@ public:
beta_ = ElementCompute(1);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source,
ElementCompute uniform = ElementCompute(0)) const {
FragmentOutput const &source) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
@ -302,21 +327,16 @@ public:
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
maximum<ComputeFragment> max_accumulator;
ReLu<FragmentAccumulator> relu;
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
// Clamp to threshold
intermediate = max_accumulator(intermediate, threshold_);
// Convert back to accumulator data type
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
CUTLASS_PRAGMA_UNROLL
@ -324,8 +344,46 @@ public:
scaled_accumulator[i] = static_cast<int>(intermediate[i]);
}
// Convert to destination numeric type and pack
NumericArrayConverter<ElementOutput, ElementAccumulator, kCount, Round> destination_converter;
// Compute threshold optionally
scaled_accumulator = relu(threshold_, scaled_accumulator);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
return destination_converter(scaled_accumulator);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
ReLu<FragmentAccumulator> relu;
intermediate = mul_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
// Convert floats back to INT
FragmentAccumulator scaled_accumulator;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
scaled_accumulator[i] = static_cast<int>(intermediate[i]);
}
// Compute threshold optionally
scaled_accumulator = relu(threshold_, scaled_accumulator);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
return destination_converter(scaled_accumulator);
}
@ -338,3 +396,6 @@ public:
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,206 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/epilogue/thread/activation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
int Count, ///< Number of elements computed per operation
typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationSigmoid {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
/// Host-constructable parameters structure
struct Params {
ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory
//
// Methods
//
CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
}
CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
}
};
private:
//
// Data members
//
ElementCompute alpha_;
ElementCompute beta_;
public:
/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationSigmoid(Params const &params) {
alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return beta_ != ElementCompute(0);
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition) {
if (k_partition) {
beta_ = ElementCompute(1);
}
}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
Sigmoid<ComputeFragment> sigmoid;
intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
intermediate = sigmoid(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_accumulator;
Sigmoid<ComputeFragment> sigmoid;
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
intermediate = sigmoid(intermediate);
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
return destination_converter(intermediate);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
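A sketch of invoking the new sigmoid epilogue functor directly on the host (illustrative; the header path cutlass/epilogue/thread/linear_combination_sigmoid.h is an assumption, since file names are not shown in this view):

#include "cutlass/array.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"   // assumed path

int main() {
  using Op = cutlass::epilogue::thread::LinearCombinationSigmoid<float, 4>;

  // beta == 0, so is_source_needed() returns false and the accumulator-only overload applies.
  Op op(Op::Params(/*alpha=*/2.0f, /*beta=*/0.0f));

  cutlass::Array<float, 4> accum;
  accum.fill(0.0f);

  cutlass::Array<float, 4> d = op(accum);   // sigmoid(2 * 0) == 0.5 in every lane

  return (!op.is_source_needed() && d[0] == 0.5f) ? 0 : 1;
}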

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -45,6 +45,7 @@
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
@ -76,6 +77,7 @@ template <
/// Elements accessed by inner-most loop of AccumulatorFragmentIterator::load()
int ElementsPerAccess,
/// Multiply-add operator
/// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex)
typename Operator_ = arch::OpMultiplyAddComplex>
struct DefaultEpilogueComplexTensorOp {
@ -146,6 +148,91 @@ struct DefaultEpilogueComplexTensorOp {
>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization that defines sensible defaults for epilogues for the complex*complex case
// 3 real-valued mma operations (Gaussian Complex)
// A = (ar + j ai), B = (br +j bi), D = AB
// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi)
// D = dr + j di = (P1 - P3) + j (P1 + P2)
/////////////////////////////////////////////////////////////////////////////////////////////////
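// Worked example (illustrative, not part of the header): for A = 1 + 2j and B = 3 + 4j,
// P1 = (1 + 2) * 3 = 9, P2 = -1 * (3 - 4) = 1, P3 = 2 * (3 + 4) = 14, so
// D = (P1 - P3) + j (P1 + P2) = -5 + 10j, which matches (1 + 2j)(3 + 4j) computed directly.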
template <
typename Shape_,
typename WarpMmaTensorOp_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess
>
struct DefaultEpilogueComplexTensorOp <Shape_, WarpMmaTensorOp_, PartitionsK,
OutputOp_, ElementsPerAccess,
arch::OpMultiplyAddGaussianComplex> {
using Shape = Shape_;
using WarpMmaTensorOp = WarpMmaTensorOp_;
static int const kPartitionsK = PartitionsK;
using OutputOp = OutputOp_;
static int const kElementsPerAccess = ElementsPerAccess;
using Operator = arch::OpMultiplyAddGaussianComplex;
using ElementOutput = typename OutputOp::ElementOutput;
using LayoutC = typename WarpMmaTensorOp::LayoutC;
using ElementAccumulator = typename WarpMmaTensorOp::ElementC;
//
// Thread map
//
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
Shape,
typename WarpMmaTensorOp::Shape,
kPartitionsK,
ElementOutput,
kElementsPerAccess
>::Type;
using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<
OutputTileThreadMap,
ElementOutput
>;
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorGaussianComplexTensorOp<
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
typename WarpMmaTensorOp::Policy::Operator::ElementC,
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
LayoutC
>;
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
ElementAccumulator,
LayoutC
>;
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
typename OutputTileThreadMap::CompactedThreadMap,
ElementAccumulator
>;
/// Hard-coded padding elements added
using Padding = cutlass::MatrixShape<0, 0>;
//
// Define the epilogue
//
using Epilogue = cutlass::epilogue::threadblock::Epilogue<
Shape,
WarpMmaTensorOp,
kPartitionsK,
OutputTileIterator,
AccumulatorFragmentIterator,
WarpTileIterator,
SharedLoadIterator,
OutputOp,
Padding
>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -148,6 +148,44 @@ struct DefaultEpiloguePlanarComplex<
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues.
template <
typename ThreadblockShape_,
typename WarpMmaOperator_,
int PartitionsK,
typename OutputOp_,
int ElementsPerAccess
>
struct DefaultEpiloguePlanarComplex<
ThreadblockShape_,
WarpMmaOperator_,
arch::OpClassTensorOp,
arch::Sm80,
PartitionsK,
OutputOp_,
ElementsPerAccess> {
using RealEpilogue = DefaultEpilogueTensorOp<
ThreadblockShape_,
WarpMmaOperator_,
PartitionsK,
OutputOp_,
ElementsPerAccess
>;
using Epilogue = EpiloguePlanarComplex<
ThreadblockShape_,
WarpMmaOperator_,
PartitionsK,
typename RealEpilogue::OutputTileIterator,
typename RealEpilogue::AccumulatorFragmentIterator,
typename RealEpilogue::WarpTileIterator,
typename RealEpilogue::SharedLoadIterator,
OutputOp_,
typename RealEpilogue::Padding
>;
};
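// The Sm80 specialization above assembles EpiloguePlanarComplex from the iterator types of the
// real-valued DefaultEpilogueTensorOp (aliased as RealEpilogue), shared by the real and imaginary planes.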
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues.

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -39,6 +39,7 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -39,16 +39,20 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
@ -61,6 +65,177 @@ namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <
typename ElementOutput,
typename ElementAccumulator,
int ElementsPerAccess,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename ThreadMap
>
struct DefaultIteratorsTensorOp {
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
WarpShape,
InstructionShape,
ElementAccumulator,
layout::RowMajor
>;
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
ThreadMap,
ElementAccumulator
>;
};
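// The partial specializations below substitute the *Mixed warp tile and shared-load iterators,
// which rearrange the shared-memory layout to avoid bank conflicts for narrow output types.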
/// Partial specialization for half <= float x 8 epilogues avoids shared memory bank conflicts.
template <
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename ThreadMap
>
struct DefaultIteratorsTensorOp<
half_t,
float,
8,
ThreadblockShape,
WarpShape,
InstructionShape,
ThreadMap> {
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
WarpShape,
InstructionShape,
float,
32,
16,
8,
8
>;
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
ThreadMap,
float,
32,
16,
8,
8
>;
};
/// Partial specialization for int8_t x 16 <= int32_t x 16 epilogues avoids shared memory bank conflicts.
template <
int K,
typename InstructionShape,
typename ThreadMap
>
struct DefaultIteratorsTensorOp<
int8_t,
int32_t,
16,
gemm::GemmShape<128, 128, K>,
gemm::GemmShape<64, 64, K>,
InstructionShape,
ThreadMap> {
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
gemm::GemmShape<64, 64, K>,
InstructionShape,
int32_t,
32,
8,
16,
8
>;
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
ThreadMap,
int32_t,
32,
8,
16,
8
>;
};
/// Partial specialization for int8_t x 8 <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template <
int K,
typename InstructionShape,
typename ThreadMap
>
struct DefaultIteratorsTensorOp<
int8_t,
int32_t,
8,
gemm::GemmShape<128, 64, K>,
gemm::GemmShape<64, 32, K>,
InstructionShape,
ThreadMap> {
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
gemm::GemmShape<64, 32, K>,
InstructionShape,
int32_t,
32,
8,
8,
8
>;
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
ThreadMap,
int32_t,
32,
8,
8,
8
>;
};
/// Partial specialization for int8_t x 8 <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template <
int K,
typename InstructionShape,
typename ThreadMap
>
struct DefaultIteratorsTensorOp<
int8_t,
int32_t,
8,
gemm::GemmShape<64, 64, K>,
gemm::GemmShape<32, 32, K>,
InstructionShape,
ThreadMap> {
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
gemm::GemmShape<32, 32, K>,
InstructionShape,
int32_t,
32,
8,
8,
8
>;
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
ThreadMap,
int32_t,
32,
8,
8,
8
>;
};
} // namespace detail
////////////////////////////////////////////////////////////////////////////////
/// Defines sensible defaults for epilogues for TensorOps.
template <
typename Shape_,
@ -98,25 +273,33 @@ struct DefaultEpilogueTensorOp {
ElementOutput
>;
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
typename WarpMmaTensorOp::Policy::Operator::ElementC,
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
LayoutC
>;
using AccumulatorFragmentIterator = typename std::conditional<is_complex<ElementOutput>::value,
cutlass::epilogue::warp::FragmentIteratorComplexTensorOp<
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
typename WarpMmaTensorOp::Policy::Operator::ElementC,
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
LayoutC>,
cutlass::epilogue::warp::FragmentIteratorTensorOp<
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
typename WarpMmaTensorOp::Policy::Operator::ElementC,
typename WarpMmaTensorOp::Policy::Operator::FragmentC,
LayoutC> >::type;
using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp<
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
/// Support several implementations depending on structure of epilogue
using DefaultIterators = detail::DefaultIteratorsTensorOp<
ElementOutput,
ElementAccumulator,
LayoutC
kElementsPerAccess,
Shape,
typename WarpMmaTensorOp::Shape,
typename WarpMmaTensorOp::Policy::Operator::Shape,
typename OutputTileThreadMap::CompactedThreadMap
>;
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<
typename OutputTileThreadMap::CompactedThreadMap,
ElementAccumulator
>;
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
/// Hard-coded padding elements added
using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits<ElementAccumulator>::value * 4>;
@ -184,6 +367,7 @@ struct DefaultInterleavedEpilogueTensorOp {
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -39,6 +39,7 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -39,6 +39,7 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -146,54 +146,6 @@ struct DefaultInterleavedThreadMapTensorOp {
////////////////////////////////////////////////////////////////////////////////
/// Defines the optimal thread map for TensorOp accumulator layouts
template <typename ThreadblockShape_, typename WarpShape_, int PartitionsK,
typename Element_, int ElementsPerAccess, int InterleavedK>
struct DefaultInterleavedConvThreadMapTensorOp {
using ThreadblockShape = ThreadblockShape_;
using WarpShape = WarpShape_;
static int const kPartitionsK = PartitionsK;
using Element = Element_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kInterleavedK = InterleavedK;
//
// Definitions
//
struct Detail {
/// Tensor Operations fundamentally perform operations on 8 rows
static int const kTensorOpRows = 8;
static int const kWarpSize = 32;
static_assert(!(ThreadblockShape::kM % WarpShape::kM) &&
!(ThreadblockShape::kM % WarpShape::kM),
"Divisibility");
/// Number of warps
using WarpCount =
gemm::GemmShape<ThreadblockShape::kM / WarpShape::kM,
ThreadblockShape::kN / WarpShape::kN, kPartitionsK>;
/// Number of participating threads
static int const kThreads = WarpCount::kCount * kWarpSize;
};
//
// ThreadMap
//
/// ThreadMap to be used by epilogue::MaskedTileIterator satisfying concept
/// InterleavedOutputTileThreadMap
using Type = InterleavedConvOutputTileThreadMap<
MatrixShape<Detail::WarpCount::kM, Detail::WarpCount::kN>,
MatrixShape<WarpShape::kM / Detail::kTensorOpRows,
WarpShape::kN / InterleavedK>,
Detail::kThreads, kElementsPerAccess, sizeof_bits<Element>::value>;
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -175,15 +175,106 @@ public:
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
int64_t imag_stride_dest = 0, ///< Arguments required for planar complex case - not used in real-valued case
int64_t imag_stride_src = 0) { ///<
typename OutputTileIterator::Fragment source_fragment;
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
if (!output_op.is_source_needed()) {
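// The output operator does not need the source tile (e.g. beta == 0), so disable the source
// iterator's global loads and take the accumulator-only path.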
source_iterator.clear_mask();
compute_source_not_needed_(output_op, destination_iterator, accumulators);
}
else {
compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator);
}
}
private:
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators) { ///< Complete warp-level accumulator tile
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
//
// Convert and store fragment
//
__syncthreads();
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
this->warp_tile_iterator_.store(accum_fragment);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
if (kPartitionsK > 1)
{
plus <typename SharedLoadIterator::Fragment> add_fragments;
const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
CUTLASS_PRAGMA_UNROLL
for ( int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment[0]);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
}
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const &output_op, ///< Output operator
OutputTileIterator destination_iterator, ///< Tile iterator for destination
AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile
OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles)
typename OutputTileIterator::Fragment source_fragment;
source_fragment.clear();
@ -265,8 +356,6 @@ public:
}
}
private:
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
@ -294,6 +383,30 @@ private:
output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
typename OutputTileIterator::Fragment &output_fragment,
OutputOp const &output_op, ///< Output operator
typename SharedLoadIterator::Fragment const &aligned_accum_fragment) {
OutputAccessType *output_frag_ptr =
reinterpret_cast<OutputAccessType *>(&output_fragment);
AccumulatorAccessType const *compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
int const kOutputOpIterations =
OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = output_op(compute_frag_ptr[i]);
}
}
};
////////////////////////////////////////////////////////////////////////////////

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -490,67 +490,6 @@ struct InterleavedOutputTileThreadMap {
////////////////////////////////////////////////////////////////////////////////
/// Template metaprogram for partitioning a 4D interleaved layout across warps
/// to achieve several performance objectives:
///
/// - coalesced memory accesses in units of 64 Byte lines
/// - minimal address arithmetic
/// - minimal predicate calculations
///
template <typename WarpCount_, typename Iterations_, int Threads,
int ElementsPerAccess, int ElementSize>
struct InterleavedConvOutputTileThreadMap {
using WarpCount = WarpCount_;
static int const kWarpSize = 32;
static int const kThreads = Threads;
static int const kWarpCount = kThreads / kWarpSize;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kElementSize = ElementSize;
//
// Metaprogram computation
//
struct Detail {};
//
// Output
//
using Iterations = Iterations_;
using Delta = MatrixShape<kWarpSize / 4, 4 * kElementsPerAccess>;
/// Initial offset function
CUTLASS_HOST_DEVICE
static MatrixCoord initial_offset(int thread_idx) {
int warp_idx = thread_idx / kWarpSize;
int lane_idx = thread_idx % kWarpSize;
// Compute warp location
MatrixCoord warp_footprint{
Delta::kRow * Iterations::kRow,
Delta::kColumn * Iterations::kColumn,
};
MatrixCoord warp_offset{warp_idx % WarpCount::kRow,
warp_idx / WarpCount::kRow};
// Compute per-lane offset
MatrixCoord thread_offset_in_warp{lane_idx / 4,
(lane_idx % 4) * kElementsPerAccess};
MatrixCoord thread_offset_in_threadblock_tile =
warp_footprint * warp_offset + thread_offset_in_warp;
return thread_offset_in_threadblock_tile;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

Some files were not shown because too many files have changed in this diff.