# CUTLASS 3.0 GEMM API

CUTLASS presents a uniform programming model
for matrix multiply-accumulate (MMA) operations
at different levels of the GPU system hierarchy.
CUTLASS 3.0 has GEMM APIs corresponding to the following levels,
in order from the highest to the lowest level.

1. Device
2. Kernel
3. Collective
4. Tiled MMA and Copy
5. Atom

This document covers the first three levels in detail:
Device, Kernel, and Collective.
It also briefly discusses the Tiled MMA/Copy and Atom levels,
and then refers readers to CuTe's tutorial for more information.

# CUTLASS GEMM Model

CUTLASS implements algorithms that express
the classical "triply nested loop" GEMM algorithm
with a tiled structure mirroring the above hierarchy.

The following pseudocode describes the model for a GEMM kernel
targeting a warp-synchronous matrix multiply instruction like `mma.sync`.
The entire operation is referred to as "Gemm,"
as it is assumed that an epilogue operation
performs the general matrix update similar to BLAS.
This is pseudocode and is only meant to illustrate which parts of the layers
correspond to the inner or outer loops of the GEMM.

```c++
// cutlass::gemm::kernel::GemmUniversal: ClusterTileM and ClusterTileN loops
//   are either rasterized by the hardware or scheduled by the kernel in persistent kernels.
// Parallelism over thread block clusters
for (int cluster_m = 0; cluster_m < GemmM; cluster_m += ClusterTileM) {
  for (int cluster_n = 0; cluster_n < GemmN; cluster_n += ClusterTileN) {

    // cutlass::gemm::collective::CollectiveMma: mainloop that iterates over all k-tiles
    // No loop unrolling is performed at this stage
    for (int k_tile = 0; k_tile < size<2>(gmem_tensor_A); k_tile++) {

      // loops inside cute::gemm(tiled_mma, a, b, c); Dispatch 5: (V,M,K) x (V,N,K) => (V,M,N)
      // TiledMma uses the hardware instruction provided through its Mma_Atom
      // TiledMma's atom layout, value layout, and permutations define the iteration order
      for (int tiled_mma_k = 0; tiled_mma_k < size<2>(A); tiled_mma_k++) {
        for (int tiled_mma_m = 0; tiled_mma_m < size<1>(A); tiled_mma_m++) {
          for (int tiled_mma_n = 0; tiled_mma_n < size<1>(B); tiled_mma_n++) {

            // TiledMma's vector mode dispatches to the underlying instruction.
            mma.call(d, a, b, c);
          } // tiled_mma_n
        } // tiled_mma_m
      } // tiled_mma_k
    } // k_tile mainloop
  } // cluster_n
} // cluster_m
```

The first two nested `for` loops (`cluster_m` and `cluster_n`)
correspond to parallelism over thread block clusters.
The code does not actually express them as explicit `for` loops.
Instead, the parallelization scheme over tiles
is implied by CUDA grid launch semantics.
However, for persistent kernels,
these loops are expressed in the source code
as a single `while` loop that queries the
[work tile scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp)
for problem tiles on which to compute.
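
The persistent-kernel `while` loop can be pictured with a small host-side sketch in plain C++.
Everything here is hypothetical and only illustrates the control flow:
`ToyTileScheduler`, `WorkTile`, and `run_persistent_worker` are made-up names,
not the CUTLASS scheduler API, and the strided tile order is an assumption.

```cpp
#include <cassert>
#include <vector>

// One unit of work: the coordinates of an output tile.
struct WorkTile {
  int m_idx;
  int n_idx;
  bool valid;
};

// Hypothetical scheduler (not the CUTLASS class): hands out tiles in a
// strided order so that no two workers ever receive the same tile.
class ToyTileScheduler {
public:
  ToyTileScheduler(int tiles_m, int tiles_n, int worker_id, int num_workers)
      : tiles_m_(tiles_m), tiles_n_(tiles_n),
        linear_idx_(worker_id), stride_(num_workers) {
    assert(tiles_m > 0 && tiles_n > 0 && num_workers > 0);
  }

  // Returns the tile currently assigned to this worker, or an invalid tile
  // once all tiles have been handed out.
  WorkTile get_current_work() const {
    if (linear_idx_ >= tiles_m_ * tiles_n_) {
      return WorkTile{0, 0, false};
    }
    return WorkTile{linear_idx_ % tiles_m_, linear_idx_ / tiles_m_, true};
  }

  void advance_to_next_work() { linear_idx_ += stride_; }

private:
  int tiles_m_;
  int tiles_n_;
  int linear_idx_;
  int stride_;
};

// A persistent worker replaces the two tile loops with a single while loop.
std::vector<WorkTile> run_persistent_worker(int tiles_m, int tiles_n,
                                            int worker_id, int num_workers) {
  ToyTileScheduler scheduler(tiles_m, tiles_n, worker_id, num_workers);
  std::vector<WorkTile> visited;
  WorkTile tile = scheduler.get_current_work();
  while (tile.valid) {
    visited.push_back(tile);  // the mainloop and epilogue for this tile would run here
    scheduler.advance_to_next_work();
    tile = scheduler.get_current_work();
  }
  return visited;
}
```

With a 3x2 grid of tiles and four workers, worker 0 would process two tiles and worker 3 only one,
which is the kind of load distribution a real scheduler has to reason about.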

Inside the three nested `for` loops,
one finds code that pulls matrix tiles
from global memory into more "local" memory
(like shared memory or registers)
and computes MMAs.
These tiled copy and tiled MMA iterations are generally
fully static and get fully unrolled.
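
For readers who want to run the loop structure, here is a minimal host-only sketch in plain C++.
It is not CUTLASS code: `tiled_gemm` and its `TileM`, `TileN`, `TileK` parameters are hypothetical,
the matrices are assumed to be row-major, and the extents are assumed to be multiples of the tile sizes.
The outer loops stand in for the grid, the `k_tile` loop for the mainloop,
and the innermost loops have compile-time trip counts, as the tiled MMA iterations do.

```cpp
#include <cassert>
#include <vector>

// Row-major C[M x N] += A[M x K] * B[K x N], tiled like the loop nest above.
template <int TileM, int TileN, int TileK>
void tiled_gemm(int M, int N, int K,
                const std::vector<float>& A,
                const std::vector<float>& B,
                std::vector<float>& C) {
  assert(M % TileM == 0 && N % TileN == 0 && K % TileK == 0);
  // Outer loops: one iteration per output tile (parallel across the grid on a GPU).
  for (int tile_m = 0; tile_m < M; tile_m += TileM) {
    for (int tile_n = 0; tile_n < N; tile_n += TileN) {
      // Mainloop: runtime trip count over k-tiles.
      for (int k_tile = 0; k_tile < K; k_tile += TileK) {
        // Inner loops: static trip counts, so a compiler can fully unroll them.
        for (int k = 0; k < TileK; ++k) {
          for (int m = 0; m < TileM; ++m) {
            for (int n = 0; n < TileN; ++n) {
              C[(tile_m + m) * N + (tile_n + n)] +=
                  A[(tile_m + m) * K + (k_tile + k)] *
                  B[(k_tile + k) * N + (tile_n + n)];
            }
          }
        }
      }
    }
  }
}
```

Calling `tiled_gemm<2, 2, 2>(4, 4, 4, A, B, C)` with `A` set to the identity and `C` zeroed leaves a copy of `B` in `C`.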

# CUTLASS GEMM Components

CUTLASS expresses the above loop nest
with the following components, which are specialized for
data type, layout, and math instruction.

| API level            | API Class and/or function names                   |
| ---                  | ---                                               |
| Device               | `cutlass::gemm::device::GemmUniversalAdapter`     |
| Kernel               | `cutlass::gemm::kernel::GemmUniversal`            |
| Collective           | `cutlass::gemm::collective::CollectiveMma` <br /> `cutlass::epilogue::collective::DefaultEpilogue` <br /> `cutlass::epilogue::collective::Epilogue` |
| Tiled (MMA and Copy) | `cute::TiledMma` and `cute::TiledCopy` <br /> `cute::gemm()` and `cute::copy()` |
| Atom                 | `cute::Mma_Atom` and `cute::Copy_Atom`            |

In CUTLASS 3.0, we assemble kernels
by first composing a collective mainloop and collective epilogue
together at the kernel layer,
and then wrapping them with a host-side adapter
to form a GEMM handle to that kernel.

The following sections describe these components
in the order a user should instantiate them
to assemble a kernel. This order is

1. assemble the required collective mainloop and epilogues,

2. compose them together to build a kernel type, and

3. wrap up the kernel with a device layer adapter.

This order is also reflected in the [CUTLASS 3.0 Hopper kernel examples](/examples/48_hopper_warp_specialized_gemm) as seen in the excerpt below.

```c++
// Step 1: Generate the required collective layer mainloop specialization
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;

// Step 2: Specify the collective layer epilogue type
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
    cutlass::gemm::TagToStrideC_t<LayoutC>,
    cutlass::gemm::TagToStrideC_t<LayoutC>,
    cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>>;

// Step 3: Compose the mainloop and epilogue together at the kernel layer
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int,int,int,int>, // ProblemShape [M,N,K,L]
    CollectiveMainloop,
    CollectiveEpilogue
>;

// Step 4: Wrap up the kernel::GemmUniversal kernel class
// with the device adapter to obtain a host-side handle to the kernel
using GemmHandle = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
```

Towards the end, we also briefly cover CuTe's tiled MMA and copy as well as the atom layer APIs,
before redirecting users to CuTe-specific documentation for further details.

## Collective API

A Collective is "the largest collection of threads
onto which mma atoms and copy atoms are tiled."
That is, it is the largest number of threads in a grid
that can cooperate by leveraging hardware features
for accelerated communication and synchronization.
These hardware features include

* asynchronous array copy
  (e.g., from global memory to shared memory);

* MMA instructions
  for small tiles that live in shared memory;

* synchronization operations for clusters,
  thread blocks, and/or warps; and/or

* hardware acceleration (such as barriers)
  for ensuring that data dependencies
  between asynchronous operations are met.

A Collective uses the `TiledMma` and `TiledCopy` API (see below)
to access operations that copy and perform MMA on tiles.

Different units of parallelism
(e.g., threads, warps, or thread blocks)
in a Collective might have different roles.
For example, in "warp-specialized" algorithms,
some warps may be responsible for copying data,
while others may be responsible for computation.
Nevertheless, the different units of parallelism
still need to share data and coordinate access
to the shared data. For example,
the producer warps in a warp-specialized algorithm
that copy input matrix tiles into shared memory
need to let the consumer MMA warp(s) know
that their MMA inputs are ready.
We contrast this with the `kernel::` layer API,
which schedules the collectives over *independent* tiles in the grid.

The Collective API includes both the "mainloop"
of matrix multiply-accumulate, and the epilogue.
This API is the composition point for optimizations
such as mainloop fusions and epilogue fusions.
It is responsible for implementing
the `k_tile` loop in the above triply nested loop pseudocode.

### Collective Mainloops

The `cutlass::gemm::collective::CollectiveMma` class
is the primary interface to the collective
matrix multiply-accumulate (MMA) mainloops.
"Mainloop" refers to the "main loop" over tiles --
the `k_tile` loop in the pseudocode
near the top of this document.
Any looping over multiple tiles that
the algorithm might need to do would happen here.

The `CollectiveMma` class is declared in the header
[cutlass/gemm/collective/collective_mma.hpp](/include/cutlass/gemm/collective/collective_mma.hpp).

```c++
namespace cutlass::gemm::collective {

template <
  class DispatchPolicy,
  class TileShape,
  class ElementA,
  class StrideA,
  class ElementB,
  class StrideB,
  class TiledMma,
  class GmemTiledCopyA,
  class SmemLayoutAtomA,
  class SmemCopyAtomA,
  class TransformA,
  class GmemTiledCopyB,
  class SmemLayoutAtomB,
  class SmemCopyAtomB,
  class TransformB
>
struct CollectiveMma {
  static_assert(sizeof(ElementA) == 0, "Could not find a mainloop specialization.");
};

} // namespace cutlass::gemm::collective
```

- `DispatchPolicy` is the most important type for a collective, and is
  [covered in more detail below](#collective-dispatch-policies).

- `StrideA` and `StrideB` are instances of type `cute::Stride` that represent the global memory layout of A and B tensors. These strides are required to be rank-3, representing the modes `[outer, inner, batch]`. Each of the three modes can be a multi-modal hierarchical stride; this would apply if implementing a tensor contraction.

- `TiledMma` is an instance of `cute::TiledMma`.

- `GmemTiledCopyA` and `GmemTiledCopyB` are instances of `cute::TiledCopy` types. Both tiled operation types are [covered in more detail below](#tiled-mma-and-copy).

- `SmemLayoutAtomA` and `SmemLayoutAtomB` are instances of type `cute::Layout` and represent the smallest
  layout that will get tiled over the entire collective's shared memory. This layout does _not_ include the
  pipeline mode, and therefore, both are expected to be rank-2 layouts of shape [`outer`, `inner`].

- `SmemCopyAtomA` and `SmemCopyAtomB` are `Copy_Atom`s to be used for moving data from shared memory
  into register memory.

Notice that CUTLASS 3.0 mainloops do not accept a dedicated accumulator element type.
We obtain the accumulator type from `typename TiledMma::ValTypeC`. Note also that
the top-level API's `ElementA` and `ElementB` can differ from the MMA-facing
`typename TiledMma::ValTypeA` and `typename TiledMma::ValTypeB`, allowing TMA or user-supplied
transform operations to perform type conversions.
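
To make the rank-3 `[outer, inner, batch]` stride convention concrete, the sketch below computes
linear offsets from such a stride in plain C++. It is only an illustration under stated assumptions:
`ToyStride3` and the helper functions are hypothetical stand-ins, not `cute::Stride`,
and hierarchical (multi-modal) strides are not modelled.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a rank-3 stride: [outer, inner, batch].
using ToyStride3 = std::array<std::int64_t, 3>;

// Stride for a tensor whose inner mode is contiguous in memory.
inline ToyStride3 inner_contiguous_stride(std::int64_t outer_extent,
                                          std::int64_t inner_extent) {
  return ToyStride3{inner_extent, 1, outer_extent * inner_extent};
}

// Stride for a tensor whose outer mode is contiguous in memory.
inline ToyStride3 outer_contiguous_stride(std::int64_t outer_extent,
                                          std::int64_t inner_extent) {
  return ToyStride3{1, outer_extent, outer_extent * inner_extent};
}

// Linear offset of element (outer, inner, batch): the dot product of the
// coordinate with the stride.
inline std::int64_t linear_offset(const ToyStride3& stride, std::int64_t outer,
                                  std::int64_t inner, std::int64_t batch) {
  assert(outer >= 0 && inner >= 0 && batch >= 0);
  return outer * stride[0] + inner * stride[1] + batch * stride[2];
}
```

For a 4x8 tensor per batch, element `(2, 3)` of batch 1 sits at offset 51 with the inner mode contiguous
and at offset 46 with the outer mode contiguous; only the stride changes, not the indexing code.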

### Collective Dispatch Policies

`CollectiveMma` implementations are not generic.
Instead, they must be specialized for each algorithm and GPU architecture.
Users can dispatch to a `CollectiveMma` specialization
by picking template arguments matching that specialization.
CUTLASS 3.0 adopts a tag-based dispatch policy type to specialize
mainloop implementations and add tuning knobs to them.

Below is an example of one of the dispatch policies that is used to dispatch to a Hopper TMA
warp-specialized mainloop implementation:

```c++
// n-buffer in smem (Hopper TMA),
// pipelined with Hopper GMMA and TMA,
// warp-specialized dynamic schedule
template<
  int Stages_,
  class ClusterShape_ = Shape<_1,_1,_1>,
  class KernelSchedule = KernelTmaWarpSpecializedCooperative
>
struct MainloopSm90TmaGmmaWarpSpecialized {
  constexpr static int Stages = Stages_;
  using ClusterShape = ClusterShape_;
  using ArchTag = arch::Sm90;
  using Schedule = KernelSchedule;
};
```

The `Stages_` template parameter lets the user freely vary the number of pipeline stages,
while the `ClusterShape_` type allows for parameterization over the shape of the threadblock
cluster over which TMA multicast will take place.

The collective dispatch policy is also the primary point of composing various kernel schedules
freely with any mainloop. Each mainloop policy either prescribes a `Schedule` with which
it needs to be run, or exposes a template API that lets the user pick a subset of the following schedules:

```c++
struct KernelCpAsyncWarpSpecialized { };
struct KernelCpAsyncWarpSpecializedPingpong { };
struct KernelCpAsyncWarpSpecializedCooperative { };
struct KernelTma { };
struct KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperative { };
```

- A single kernel schedule can support multiple mainloop implementations. For example,
  `KernelMultistage` can be composed with many different mainloop implementations across GPU
  architectures such as `MainloopSm70TwoStage`, `MainloopSm80CpAsyncUnpredicated`, and many more.

- A single mainloop can be composed with multiple
  possible kernel schedules. For example, the `MainloopSm90TmaGmmaWarpSpecialized` can be
  composed with any of the `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative`
  kernel schedules.

As [discussed in the CUTLASS 3.0 design documentation](cutlass_3x_design.md), adopting tag
dispatch policies for our core vocabulary types allows us to maintain a single type name for
all operations that conceptually belong to the same class. This design has the following benefits.

- It *avoids code duplication* in cases where mainloops can be composed with multiple kernels or vice versa.
- It *makes writing generic code easier*, as the primary type name `CollectiveMma` does not change across any implementation.
- It *provides a clear, singular extension point* for users to plug in new, custom mainloop implementations specialized on their own dispatch policies.
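
The tag-dispatch pattern itself can be tried out in a few lines of standard C++.
The sketch below is a toy model, not CUTLASS code: every `Toy*` name, `describe_mainloop`,
and `toy_dispatch_usage` are hypothetical, and the "mainloops" only report which specialization was picked.
It shows the two ingredients discussed above: a primary template that refuses to compile unless a
specialization matches, and specializations keyed on the dispatch policy type.

```cpp
#include <cassert>
#include <string>

// Hypothetical schedule tags and dispatch policies (illustrative names only).
struct ToyScheduleCooperative {};
struct ToySchedulePingpong {};

template <int Stages_, class Schedule_ = ToyScheduleCooperative>
struct ToyPolicyTma {
  static constexpr int Stages = Stages_;
  using Schedule = Schedule_;
};

template <int Stages_>
struct ToyPolicyCpAsync {
  static constexpr int Stages = Stages_;
};

// Primary template: the dependent static_assert fires only if no
// specialization matches the requested dispatch policy.
template <class DispatchPolicy, class Element>
struct ToyCollectiveMma {
  static_assert(sizeof(Element) == 0, "Could not find a mainloop specialization.");
};

// Specialization selected by any ToyPolicyTma, whatever its stage count or schedule.
template <int Stages, class Schedule, class Element>
struct ToyCollectiveMma<ToyPolicyTma<Stages, Schedule>, Element> {
  static constexpr int stages = Stages;
  static std::string name() { return "tma"; }
};

// Specialization selected by any ToyPolicyCpAsync.
template <int Stages, class Element>
struct ToyCollectiveMma<ToyPolicyCpAsync<Stages>, Element> {
  static constexpr int stages = Stages;
  static std::string name() { return "cp_async"; }
};

// Generic code names only ToyCollectiveMma; the policy picks the implementation.
template <class Policy>
std::string describe_mainloop() {
  using Mainloop = ToyCollectiveMma<Policy, float>;
  return Mainloop::name() + ":" + std::to_string(Mainloop::stages);
}

// Usage: the same primary type name resolves to different implementations.
inline void toy_dispatch_usage() {
  assert(describe_mainloop<ToyPolicyTma<4>>() == "tma:4");
  assert(describe_mainloop<ToyPolicyCpAsync<2>>() == "cp_async:2");
}
```

Adding a new mainloop in this model means adding one policy type and one specialization;
`describe_mainloop` and any other generic code stay untouched.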

### Collective Builder for `CollectiveMma`s

The primary `CollectiveMma` is intended to be an expert user interface that allows full control over
all the properties of the collective's GPU micro-kernel. However, often a user just wants an
off-the-shelf GEMM mainloop implementation parameterized on simple configuration parameters. CUTLASS 3.0
provides [`cutlass::gemm::collective::CollectiveBuilder`](/include/cutlass/gemm/collective/collective_builder.hpp) for such scenarios.

```c++
namespace cutlass::gemm::collective {
template <
  class ArchTag,
  class OpClass,
  class ElementA,
  class GmemLayoutA,
  int AlignmentA,
  class ElementB,
  class GmemLayoutB,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  class KernelScheduleType,
  class Enable = void
>
struct CollectiveBuilder {
  static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
} // namespace cutlass::gemm::collective
```

`CollectiveBuilder` accepts CUTLASS 2.x equivalent input template arguments, and attempts to build
the best performing `CollectiveMma` from the given parameters.

- `ArchTag` is one of the SM architecture tags from `cutlass::arch::Sm*`.
- `OpClass` is one of the operator class tags from `cutlass::arch::OpClass*`.
- `ElementA` and `ElementB` are the logical value types of the A and B tensors, respectively.
- `ElementAccumulator` is the accumulator type to be used in the instruction.
- `GmemLayoutA` and `GmemLayoutB` are CUTLASS 2.x layout tags, `layout::RowMajor` or `layout::ColumnMajor`.
- `AlignmentA` and `AlignmentB` are global memory alignments of A and B tensors in terms of element count.
- `TileShape_MNK` is an instance of `cute::Shape` that is rank-3, representing the MxNxK collective tile shape.
- `ClusterShape_MNK` is an instance of `cute::Shape` that is rank-3, representing the MxNxK threadblock cluster tile shape.
- `StageCountType` is either `collective::StageCountAuto` or an instance of `collective::StageCount<N>`.
- `KernelScheduleType` is either `collective::KernelScheduleAuto` or one of the specific kernel schedule tags discussed in the [dispatch policy section](#collective-dispatch-policies) above.

`StageCountAuto` allows the collective builder to compute the size of a single stage in shared memory
and maximize the shared memory usage assuming 1 threadblock / multiprocessor occupancy.
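
The arithmetic behind an automatic stage count can be sketched as follows.
This is a deliberately simplified model and an assumption on our part, not the builder's actual code:
`auto_stage_count` is a hypothetical helper, one stage is assumed to hold exactly one A tile and one B tile,
and any shared memory a real kernel reserves for barriers or the epilogue is ignored.

```cpp
#include <cassert>
#include <cstddef>

// Number of pipeline stages that fit in shared memory when one threadblock
// owns the whole budget. Simplified: a stage is one A tile plus one B tile.
inline int auto_stage_count(std::size_t smem_capacity_bytes,
                            int tile_m, int tile_n, int tile_k,
                            std::size_t bytes_per_element_a,
                            std::size_t bytes_per_element_b) {
  assert(tile_m > 0 && tile_n > 0 && tile_k > 0);
  assert(bytes_per_element_a > 0 && bytes_per_element_b > 0);
  const std::size_t a_stage_bytes =
      static_cast<std::size_t>(tile_m) * static_cast<std::size_t>(tile_k) * bytes_per_element_a;
  const std::size_t b_stage_bytes =
      static_cast<std::size_t>(tile_n) * static_cast<std::size_t>(tile_k) * bytes_per_element_b;
  return static_cast<int>(smem_capacity_bytes / (a_stage_bytes + b_stage_bytes));
}
```

For example, with a hypothetical 200 KiB budget, a 128x128x64 tile, and 2-byte elements,
each stage needs 32 KiB, so this model yields 6 stages.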

`KernelScheduleAuto` allows the collective builder to pick the best kernel schedule available for the
given set of parameters. Users can override this choice by passing a specific kernel schedule type instead.

Note that collective builders are still in beta, and their functionality
does not map onto the full design space that the primary expert `CollectiveMma` API
allows for. We expect their supported mainloop types to expand in future releases, but
with 3.0, only SM90 tensorop kernels are supported through the builder API. The builder API
may also change in the future as we adopt user feedback.

If the builder is able to provide a collective mainloop type for the given set of parameters,
it exposes that type as the nested alias `CollectiveOp`. For more information on how to
parameterize kernels conveniently with the collective builder, please see example [49_hopper_gemm_with_collective_builder](/examples/49_hopper_gemm_with_collective_builder).

### Epilogue

The collective epilogue implements element-wise operations
involving the output matrix. Users can provide a custom
epilogue, or use one of the standard epilogues.
These live in the directory
[include/cutlass/epilogue/collective/](/include/cutlass/epilogue/collective/),
and include classes like
`cutlass::epilogue::collective::DefaultEpilogue`
and
`cutlass::epilogue::collective::Epilogue`.
CUTLASS's provided collective epilogues
do not live under `include/cutlass/gemm`
or in the `cutlass::gemm` namespace,
because they can be used for computations
other than GEMM.
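
As a rough picture of what an element-wise epilogue does, the sketch below applies a
linear-combination functor to an accumulator tile in plain C++.
The names `ToyLinearCombination` and `toy_epilogue` are hypothetical and only modelled on the idea of
an epilogue operation; they are not the CUTLASS classes, and the tile is flattened into a vector for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical element-wise functor: D = alpha * accumulator + beta * C.
struct ToyLinearCombination {
  float alpha;
  float beta;
  float operator()(float accumulator, float source) const {
    return alpha * accumulator + beta * source;
  }
};

// Applies an element-wise epilogue op to every accumulator value of a tile.
// Any callable taking (accumulator, source) can be plugged in as EpilogueOp.
template <class EpilogueOp>
void toy_epilogue(const EpilogueOp& op,
                  const std::vector<float>& accumulators,
                  const std::vector<float>& C,
                  std::vector<float>& D) {
  assert(accumulators.size() == C.size() && accumulators.size() == D.size());
  for (std::size_t i = 0; i < accumulators.size(); ++i) {
    D[i] = op(accumulators[i], C[i]);
  }
}
```

Because the functor is a template parameter, swapping in a different element-wise operation
(for instance one that also applies an activation) does not require touching the loop,
which mirrors why the epilogue is a natural fusion point.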

## Kernel API

The kernel is "a collection of all clusters in the grid."
The kernel layer schedules have four main responsibilities.

- Ordering the execution of collectives within the kernel, performing any synchronization between them that may be necessary
- Marshalling the threads of warp-specialized schedules into their respective roles
- Performing any necessary grid swizzling logic
- Tiling the input tensors with the threadblock cluster value tile before invoking the collectives on them

The Kernel API is the entry point for a grid of thread blocks
that may or may not be organized in a cluster.
It is the composition point for fusing back-to-back GEMMs,
epilogues, and/or other operations.

The entry point API for CUTLASS 3.0 kernels is the class
`cutlass::gemm::kernel::GemmUniversal`, found in the header file
[include/cutlass/gemm/kernel/gemm_universal.hpp](/include/cutlass/gemm/kernel/gemm_universal.hpp).
`GemmUniversal` is a stateless universal device kernel
that implements GEMM as the composition of two parts:

* a collective mainloop, and
* a collective epilogue

```cpp
namespace cutlass::gemm::kernel {
/*
 * Stateless universal device GEMM kernel type that treats GEMM as
 * a composition of a collective mainloop and a collective epilogue.
 *
 * Supports both the 2.x and 3.x APIs based on whether the first type is
 * a cute::tuple<> or not.
 * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
 * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
 *
 * In the following declaration, the name preceding the 'Or' refers to
 * 3.x API type argument order, and the name succeeding the 'Or' refers to
 * 2.x API type argument order. Template arguments without two names
 * belong to the 3.x API only.
**/
template <
  class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
  class CollectiveMainloopOrEpilogue_,
  class CollectiveEpilogueOrThreadblockSwizzle_,
  class TileScheduler_ = void,
  class Enable = void
>
class GemmUniversal;
} // namespace cutlass::gemm::kernel
```

*Stateless* means that the caller --
for example, the Device API described below --
manages the kernel's state.
The kernel just takes input and output parameters (`Params`).

*Universal* means that `GemmUniversal` works
for both CUTLASS 3.0 and 2.x interfaces
and across a broad range of kernel schedules.
If `GemmUniversal`'s first template argument is a `cute::Shape`,
then `GemmUniversal` assumes that the remaining template arguments
implement the 3.0 APIs. Otherwise, `GemmUniversal` assumes that
the remaining template arguments implement the 2.x APIs.
Starting with CUTLASS 3.0, the problem shape has been promoted
to a top-level template API for the GEMM kernel.
This supports fully static GEMM instantiations
where the user expects to know some or all
of the problem shapes at compile time
in order to extract even more performance.

The *collective mainloop* implements MMA on local tiles.
The *collective epilogue* addresses any operations after the MMA,
such as applying the `beta * C` part of `C := beta * C + alpha * A * B`.
The term *collective* is explained in the [Collective API](#collective-api) section above.

Specializations of `kernel::GemmUniversal` for 3.0 APIs live in
any of various `gemm_*.hpp` files in the directory
[include/cutlass/gemm/kernel/](/include/cutlass/gemm/kernel/).
Specializations for 2.x APIs can be found in the header file
[include/cutlass/gemm/kernel/gemm_universal.h](/include/cutlass/gemm/kernel/gemm_universal.h).

CUTLASS 3.x implements various embodiments of `kernel::GemmUniversal`.
Each kernel layer schedule is specialized
for a GEMM scheduling algorithm and GPU architecture.
The 3.0 API specializations are named after their architecture tag,
following the pattern `include/cutlass/gemm/kernel/{arch_tag}*.hpp`.
Which specialization to dispatch to is decided through the dispatch policy's `Schedule` type.

For example, the header file
[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp)
has a specialization of `kernel::GemmUniversal` for Hopper
that uses a warp-specialized mainloop with a persistent scheduling algorithm,
while the header file
[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp)
has a specialization of `GemmUniversal` for Hopper
that uses a warp-specialized but non-persistent algorithm.

To support composition between supported kernel schedules and mainloop dispatch policies without having to
duplicate collective mainloop implementations, GEMM kernel layer schedules can be composed with
any mainloop that specifies their corresponding kernel schedule as their `Schedule` type in the policy.
This is discussed in detail in the [collective dispatch policy section](#collective-dispatch-policies) above.

```c++
// An example of the SM90 KernelMultistage kernel's
// specialization logic that allows it to be composed
// with many mainloops such as `MainloopSm80CpAsync`
// and `MainloopSm70TwoStage`.
template <
  class ProblemShape_,
  class CollectiveMainloop_,
  class CollectiveEpilogue_,
  class TileScheduler_
>
class GemmUniversal<
  ProblemShape_,
  CollectiveMainloop_,
  CollectiveEpilogue_,
  TileScheduler_,
  std::enable_if_t<std::is_base_of_v<KernelMultistage, typename CollectiveMainloop_::DispatchPolicy::Schedule>>>
{
  // Kernel implementation elided.
};
```
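
The same `enable_if` and `is_base_of` selection can be reproduced in a self-contained toy,
shown below in standard C++17. All `Toy*` names, `selected_kernel_id`, and `check_kernel_selection`
are hypothetical; the "kernels" only carry an integer identifying which specialization was chosen.
One detail worth noticing is that a schedule tag *derived* from `ToyKernelMultistage`
also selects the multistage kernel, because `std::is_base_of_v` accepts derived types.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical schedule tags.
struct ToyKernelMultistage {};
struct ToyKernelTma {};
struct ToyKernelMultistageTuned : ToyKernelMultistage {};

// Hypothetical dispatch policy and mainloop carrying the schedule tag.
template <class Schedule_>
struct ToyDispatchPolicy {
  using Schedule = Schedule_;
};

template <class Policy>
struct ToyMainloop {
  using DispatchPolicy = Policy;
};

// Primary template is declared but not defined: a mainloop whose schedule
// matches no specialization fails to compile.
template <class Mainloop, class Enable = void>
struct ToyGemmKernel;

// Selected for any schedule that is (or derives from) ToyKernelMultistage.
template <class Mainloop>
struct ToyGemmKernel<
    Mainloop,
    std::enable_if_t<std::is_base_of_v<
        ToyKernelMultistage, typename Mainloop::DispatchPolicy::Schedule>>> {
  static constexpr int kernel_id = 1;
};

// Selected for any schedule that is (or derives from) ToyKernelTma.
template <class Mainloop>
struct ToyGemmKernel<
    Mainloop,
    std::enable_if_t<std::is_base_of_v<
        ToyKernelTma, typename Mainloop::DispatchPolicy::Schedule>>> {
  static constexpr int kernel_id = 2;
};

// Reports which kernel specialization a given schedule tag selects.
template <class Schedule>
int selected_kernel_id() {
  using Kernel = ToyGemmKernel<ToyMainloop<ToyDispatchPolicy<Schedule>>>;
  return Kernel::kernel_id;
}

// Usage: the schedule tag inside the mainloop's policy picks the kernel.
inline void check_kernel_selection() {
  assert(selected_kernel_id<ToyKernelMultistage>() == 1);
  assert(selected_kernel_id<ToyKernelMultistageTuned>() == 1);
  assert(selected_kernel_id<ToyKernelTma>() == 2);
}
```

In this model the mainloop never names a kernel directly; it only advertises a schedule,
and every kernel specialization that accepts that schedule becomes composable with it.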

## Device API

The Device API is a universal, kernel-agnostic host interface
for launching kernels and managing the lifetime of
reusable host-side parameters.

This API is how users' host-side .cu code
invokes CUTLASS's single-GPU GEMM kernels.
It serves the same purpose as cuBLAS and behaves similarly.

The entry point for the Device GEMM API is the class
`cutlass::gemm::device::GemmUniversalAdapter`.
This class lives in the header file
[include/cutlass/gemm/device/gemm_universal_adapter.h](/include/cutlass/gemm/device/gemm_universal_adapter.h).
`GemmUniversalAdapter` is a stateful, reusable handle,
which is parameterized on the `cutlass::gemm::kernel` type.

```c++
/*!
  GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel
  of type cutlass::gemm::kernel::*

  It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs
  to create it from the host facing arguments. For power users, new static methods
  are exposed in 3.x APIs that bypass the stateful methods or args->params lowering.

  It supports kernel types that implement both the 2.x and 3.0 APIs,
  however, this is done by specializing the implementation of GemmUniversalAdapter
  on the two kernel API types, and thus, GemmUniversalAdapter's behavior might
  differ between the two specializations.
*/
template <class GemmKernel_, class Enable = void>
class GemmUniversalAdapter;
```

*Stateful* means that the handle instance contains state
that the kernel needs to run.
This means that the user must initialize the handle first,
then use the initialized handle instance to run the kernel.
Statefulness also means that the handle can manage the lifetime
of the kernel's `Params` -- the parameters of the kernel itself.
An important duty of `GemmUniversalAdapter`
is to map from the user's `Arguments` --
what the user sees as the kernel's parameters --
to the `Params` that the kernel actually sees.
For power users, the class exposes new static methods
in 3.0 APIs that can bypass stateful methods
or go directly to `Params` without intermediate `Arguments`.
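
The division of labor between a stateful handle and a stateless kernel can be sketched in plain C++.
The toy below is an assumption-laden illustration, not the `GemmUniversalAdapter` implementation:
`ToyScaleKernel`, `ToyAdapter`, and all of their member names are hypothetical,
and the "kernel" is an ordinary host function that scales an array.
The handle lowers user-facing `Arguments` to kernel-facing `Params` once, stores them, and can then run the kernel.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stateless kernel: it only consumes Params.
struct ToyScaleKernel {
  // What the user provides.
  struct Arguments {
    const float* input;
    float* output;
    int count;
    float alpha;
  };

  // What the kernel consumes; may hold derived values such as a block count.
  struct Params {
    const float* input;
    float* output;
    int count;
    float alpha;
    int num_blocks;
  };

  static bool can_implement(const Arguments& args) {
    return args.input != nullptr && args.output != nullptr && args.count >= 0;
  }

  // Arguments -> Params lowering.
  static Params to_params(const Arguments& args) {
    constexpr int kBlockSize = 4;
    return Params{args.input, args.output, args.count, args.alpha,
                  (args.count + kBlockSize - 1) / kBlockSize};
  }

  static void run(const Params& params) {
    assert(params.count >= 0);
    for (int i = 0; i < params.count; ++i) {
      params.output[i] = params.alpha * params.input[i];
    }
  }
};

// Hypothetical stateful, reusable handle parameterized on the kernel type.
template <class Kernel>
class ToyAdapter {
public:
  // Validates the arguments and stores the lowered Params in the handle.
  bool initialize(const typename Kernel::Arguments& args) {
    if (!Kernel::can_implement(args)) {
      return false;
    }
    params_ = Kernel::to_params(args);
    initialized_ = true;
    return true;
  }

  // Runs the kernel with the stored Params; fails if never initialized.
  bool run() const {
    if (!initialized_) {
      return false;
    }
    Kernel::run(params_);
    return true;
  }

  const typename Kernel::Params& params() const { return params_; }

private:
  typename Kernel::Params params_{};
  bool initialized_ = false;
};
```

Calling `initialize` again with different `Arguments` reuses the same handle for a new problem,
which is the sense of "reusable" discussed in this section.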

*Reusable* means that the handle instance can be used
to call the kernel multiple times with different arguments
(e.g., different matrices).
Reusing the handle may be more efficient than just
creating a new handle for each kernel invocation.

*Parameterized on the kernel type* means that
the `GemmUniversalAdapter` class's behavior
depends on the GEMM kernel type (see the [Kernel API](#kernel-api) section above).
Specifically, `GemmUniversalAdapter` has a template parameter
`GemmKernel`, which is the GEMM kernel type.
Valid template arguments for `GemmKernel` are

* `cutlass::gemm::kernel::GemmUniversal`,
  implementing CUTLASS 3.x API kernels;
* `cutlass::gemm::kernel::GemmUniversal`,
  implementing CUTLASS 2.x API kernels; or
* Any valid CUTLASS 2.x `kernel` layer GEMM that
  was previously composable with the `device::GemmUniversalAdapter`.

`GemmUniversalAdapter` presents a single
host-side interface to both 3.0 and 2.x kernels.
CUTLASS accomplishes this by
specializing `GemmUniversalAdapter`'s implementation
on either the 2.x API implementing kernel layer GEMMs, or on the 3.x API
implementing kernel layer GEMMs. The metafunction [`cutlass::gemm::detail::IsCutlass3GemmKernel`](cutlass_3x_backwards_compatibility.md#kernel-api-design-differences)
is what `GemmUniversalAdapter` uses to distinguish between 2.x and 3.x kernels.

`GemmUniversalAdapter` sets up and launches the kernel, using the
CUDA extended launch API for threadblock cluster support if required.
Note that `GemmUniversalAdapter` does *not* specify the grid shape.
The kernel controls the grid shape
and other kernel-specific launch parameters.
This makes it possible for all 3.0 kernels
to use the same kernel launch code,
thus factoring out kernel launch from the actual kernel.

## Tiled MMA and Copy

The Tiled MMA or Copy are tilings of MMA atoms or Copy atoms, respectively,
across threads and data, with possible permutations applied to the
resulting tiling. This layer is most analogous to the warp-level
tiling of MMA instructions in CUTLASS 2.x. However, it views the tiling
from the perspective of all threads participating in the operation
and generalizes the concept to copy operations as well. The purpose
of this layer is to build composable GPU micro-kernels out of a plethora
of hardware-accelerated math and data movement operations, each with its own
unit layouts in threads and data. The tiled MMA and Copy types present
all these various hardware-accelerated CuTe Atoms with a single, consistent
API.

The resulting tiled operation acts as a single MMA or copy operation
that users can invoke in the "inner" loop
of the three-nested-loops pseudocode
at the top of this document using `cute::gemm()` or `cute::copy()`.

We call this API "tiled" because it constructs
larger operations out of the Atoms provided by CuTe,
as if fitting together individual tiles
to build a reusable component of a mosaic.
For example, CuTe might provide an MMA Atom
that users can call on a single warp,
for fixed M, N, and K dimensions.
CUTLASS can then use CuTe operations like `make_tiled_mma`
to turn this Atom into an operation
that works on an entire thread block,
for larger M, N, and K dimensions.
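
The "mosaic" idea can be illustrated with a sequential toy in plain C++.
This sketch only models how a fixed-size operation is repeated over data;
it does not model how CuTe distributes atoms across threads, and
`toy_mma_atom`, `toy_tiled_mma`, and the 2x2x2 atom shape are hypothetical choices made for the example.
Matrices are assumed to be row-major.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Fixed extents of the hypothetical atom.
constexpr int kAtomM = 2;
constexpr int kAtomN = 2;
constexpr int kAtomK = 2;

// The "atom": a multiply-accumulate of fixed size kAtomM x kAtomN x kAtomK
// on row-major operands with the given leading dimensions.
inline void toy_mma_atom(const float* a, int lda,
                         const float* b, int ldb,
                         float* c, int ldc) {
  for (int m = 0; m < kAtomM; ++m) {
    for (int n = 0; n < kAtomN; ++n) {
      for (int k = 0; k < kAtomK; ++k) {
        c[m * ldc + n] += a[m * lda + k] * b[k * ldb + n];
      }
    }
  }
}

// The "tiled" operation: repeats the atom AtomsM x AtomsN x AtomsK times
// to cover a larger M x N x K problem.
template <int AtomsM, int AtomsN, int AtomsK>
void toy_tiled_mma(const std::vector<float>& A,
                   const std::vector<float>& B,
                   std::vector<float>& C) {
  constexpr int M = AtomsM * kAtomM;
  constexpr int N = AtomsN * kAtomN;
  constexpr int K = AtomsK * kAtomK;
  assert(A.size() == static_cast<std::size_t>(M * K));
  assert(B.size() == static_cast<std::size_t>(K * N));
  assert(C.size() == static_cast<std::size_t>(M * N));
  for (int am = 0; am < AtomsM; ++am) {
    for (int an = 0; an < AtomsN; ++an) {
      for (int ak = 0; ak < AtomsK; ++ak) {
        toy_mma_atom(&A[(am * kAtomM) * K + ak * kAtomK], K,
                     &B[(ak * kAtomK) * N + an * kAtomN], N,
                     &C[(am * kAtomM) * N + an * kAtomN], N);
      }
    }
  }
}
```

Here `toy_tiled_mma<2, 2, 2>` turns the 2x2x2 atom into a 4x4x4 operation;
callers see one larger operation and never invoke the atom directly.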

## Atom API

An "Atom" is the smallest collection of threads and data
that must participate in the execution of a hardware-accelerated
math or copy operation.

An Atom is "atomic" (indivisible) not in the sense of
concurrent memory operations like `atomicAdd`
(which are "indivisible in time (causality)"),
but in the sense of indivisibility in "space" --
the number of values and the groups of parallel workers
that must participate in the operation together.

An Atom uses CuTe Layouts to express the required
dimensions and strides of its input and output arrays.
Generally these are fixed at compile time.

The Atom API wraps calls to actual hardware instructions
that accelerate MMA or copy operations.
Users can ask for GPU architecture-specific implementations,
or just pick generic implementations and rely on
whatever GPU architectures were enabled.

For more information about Atoms,
please refer to CuTe's tutorial, e.g., the sections on

* [algorithms](./cute/04_algorithms.md) like `gemm` and `copy`,

* [MMA Atoms](./cute/0t_mma_atom.md#cute-mma-atoms), and

* [a GEMM example](./cute/0x_gemm_tutorial.md).

# Copyright

Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause

```
  Redistribution and use in source and binary forms, with or without
  modification, are permitted provided that the following conditions are met:

  1. Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

  2. Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

  3. Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```