# CuTe dense matrix-matrix multiply tutorial
In this section, we review
[these examples](../../../examples/cute/tutorial/),
which demonstrate a few self-contained, single-file dense matrix-matrix multiply implementations using only CuTe.
## `sgemm_1.cu`
The simplest of the tutorial examples covers the basics of partitioning the global memory into tiles across the CTAs (also called threadblocks in CUDA), partitioning the data tiles across the threads of each CTA, and writing a mainloop using `cute::copy` and `cute::gemm`.
### High-level interface
We'll start with the kernel entry point `gemm_device` at the top of the file.
```c++
template <class ProblemShape, class CtaTiler,
class TA, class AStride, class ASmemLayout, class AThreadLayout,
class TB, class BStride, class BSmemLayout, class BThreadLayout,
class TC, class CStride, class CSmemLayout, class CThreadLayout,
class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
TC * C, CStride dC, CSmemLayout , CThreadLayout tC,
Alpha alpha, Beta beta)
```
There are many template parameters, so let's quickly review them and then go into more depth on their uses.
* `ProblemShape`. The MxNxK problem shape of this matrix multiply.
* `CtaTiler`. A CuTe [tiler concept](./02_layout_algebra.md#composition-tilers) that determines how to extract a tile of data from the problem shape.
* `TA const* A`, `TB const* B`, `TC* C`. The types and pointers to the A, B, and C data, respectively.
* `AStride`, `BStride`, `CStride`. The layout strides corresponding to the `ProblemShape` for each A, B, and C.
* `ASmemLayout`, `BSmemLayout`, `CSmemLayout`. The layouts, if needed, of shared memory to use for staging A-data, B-data, and C-data within each CTA.
* `AThreadLayout`, `BThreadLayout`, `CThreadLayout`. The layouts of threads to be used in partitioning each stage.
* `Alpha alpha`, `Beta beta`. The types and values of the scalar constants to compute GEMM: `C = alpha * A * B + beta * C`.
### The Full Tensors: Shapes, Strides, and Data
Most GEMM interfaces list the matrices' dimensions
in the order M, N, K. CuTe also uses this convention, but packages them
into a single `IntTuple`. In this example, they are dynamic values
defined at the top of the `gemm_nt` and `gemm_tn` host functions
that invoke the device kernel.
```cpp
// Define shapes (dynamic)
auto M = int(m);
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
```
Inside the kernel, the problem shape is checked against the preconditions and then used to construct each of the full matrices.
```cpp
// Preconditions
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK
CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK
CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN
// Represent the full tensors
Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)
```
The appropriate modes of the `Shape` are selected to construct each of the tensors. The preconditions make sure that for every integer in the `Shape` there is a corresponding integer in the associated `Stride`.
Note that the comment after B says `(N,K)` rather than `(K,N)`.
This means that B is treated as an NxK matrix instead of a KxN matrix as is typical within BLAS and most other matrix-matrix multiplications.
CuTe follows the convention that the semantics of matrix modes is
`(M,K)` for `A`, `(N,K)` for `B`, and `(M,N)` for `C`, which we try to record in comments everywhere.
For each of the `(M,K)`, `(N,K)`, and `(M,N)` tensors, the `gemm_nt` and `gemm_tn` host functions construct the strides those tensors will use. In `gemm_nt` the strides are defined as
```cpp
// Define NT strides (mixed)
auto dA = make_stride(Int<1>{}, ldA); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldB); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
```
and in `gemm_tn` the strides are defined as
```cpp
// Define TN strides (mixed)
auto dA = make_stride(ldA, Int<1>{}); // (dM, dK)
auto dB = make_stride(ldB, Int<1>{}); // (dN, dK)
auto dC = make_stride(Int<1>{}, ldC); // (dM, dN)
```
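To make the stride convention concrete, here is a host-side reference sketch in plain C++ (no CuTe). The names `Stride2` and `reference_gemm` are hypothetical and do not appear in the tutorial sources. It treats `A` as `(M,K)`, `B` as `(N,K)`, and `C` as `(M,N)`, each with one stride per mode, so the NT and TN cases differ only in the strides that are passed in.

```cpp
#include <cassert>
#include <vector>

// Illustrative host-side reference (not CuTe, not taken from sgemm_1.cu).
// A is (M,K), B is (N,K), C is (M,N); each matrix has one stride per mode.
struct Stride2 { int d0; int d1; };

void reference_gemm(int M, int N, int K,
                    float alpha,
                    const std::vector<float>& A, Stride2 dA,   // (dM, dK)
                    const std::vector<float>& B, Stride2 dB,   // (dN, dK)
                    float beta,
                    std::vector<float>& C, Stride2 dC)         // (dM, dN)
{
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      // The reduction always runs over the second mode (K) of both inputs.
      for (int k = 0; k < K; ++k) {
        acc += A[m * dA.d0 + k * dA.d1] * B[n * dB.d0 + k * dB.d1];
      }
      float& c = C[m * dC.d0 + n * dC.d1];
      c = alpha * acc + beta * c;
    }
  }
}
```

Calling it with `dA = {1, ldA}`, `dB = {1, ldB}` corresponds to the NT strides above, and `dA = {ldA, 1}`, `dB = {ldB, 1}` to the TN strides; the loop body is unchanged.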
#### Aside: M-major, N-major, K-major
We've found that the BLAS convention of using "non-transposed" (N) and "transposed" (T) flags in conjunction with the mode conventions of `MxK * KxN` confuses the core issues of "what layout does this matrix use?" and "in which mode does my matrix have a stride-1?". Indeed, the answer to those questions can always be found by inspecting the CuTe `Layout`.
Instead of row-major or column-major (or Transposed and Not-Transposed), we have found it much more convenient to say that a matrix is "M-major" if it is stride-1 in the M-mode, "N-major" if it is stride-1 in the N-mode, or "K-major" if it is stride-1 in the K-mode. Furthermore, knowing that matrix multiply always performs a reduction in the K-mode, it is very convenient from a software perspective to always have the K-mode in the same place and adopt the mode convention `MxK * NxK`. Implementations will always reduce over the second mode (the K-mode) of both input matrices, which leads to cases where implementations can treat both input matrices the same way.
How do we translate this into the BLAS user's experience?
| BLAS | A Majorness | A Layout | B Majorness | B Layout |
| --- | --- | --- | --- | --- |
| NT | M-major | `(M,K):(1,ldA)` | N-major | `(N,K):(1,ldB)` |
| TN | K-major | `(M,K):(ldA,1)` | K-major | `(N,K):(ldB,1)` |
| NN | M-major | `(M,K):(1,ldA)` | K-major | `(N,K):(ldB,1)` |
| TT | K-major | `(M,K):(ldA,1)` | N-major | `(N,K):(1,ldB)` |
Regardless, we'll still use the BLAS "NT" and "TN" notations for high-level descriptions of kernels when it's appropriate.
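The table can be read as a small mapping from BLAS transpose flags to strides. The sketch below is plain C++ with hypothetical helper names (`strides_for_A`, `strides_for_B`, `is_k_major`); it is not part of CuTe and only restates the table.

```cpp
#include <cassert>

// Illustrative helpers (not part of CuTe): strides for an (X,K) operand,
// where X is M for A and N for B.
struct Stride2 { int dX; int dK; };

// BLAS flag for A: 'N' -> M-major (1,ldA); 'T' -> K-major (ldA,1).
Stride2 strides_for_A(char transA, int ldA) {
  return (transA == 'N') ? Stride2{1, ldA} : Stride2{ldA, 1};
}

// BLAS flag for B: 'T' -> N-major (1,ldB); 'N' -> K-major (ldB,1).
Stride2 strides_for_B(char transB, int ldB) {
  return (transB == 'T') ? Stride2{1, ldB} : Stride2{ldB, 1};
}

// A matrix is "K-major" when its K-mode has stride 1
// (assuming the leading dimension is larger than 1).
bool is_k_major(Stride2 d) { return d.dK == 1; }
```

Note the asymmetry: the same flag letter means a different majorness for `A` than for `B`, which is exactly the confusion the M-major, N-major, and K-major vocabulary avoids.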
### CTA Partitioning
Now that we have the representations of the full matrices, it's time to tile them and split up the work!
At the highest level, the work is distributed across CTAs. In principle, each CTA's tile could come from the input tensors in many different ways. Many [CuTe `Tiler`s](./02_layout_algebra.md#composition-tilers) could be used to tile the data, but for these cases it is sufficient to simply use the shape of the desired CTA tile.
```cpp
// Define CTA tile sizes (static)
auto bM = Int<128>{};
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
```
Once the tiler has been defined, we can use it to tile and partition the tensors across the CTAs.
```cpp
// Get the appropriate blocks for this threadblock
auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N)
```
First, the CTA coordinate is created.
* The `m`-coordinate of this tile is given by `blockIdx.x`.
* The `n`-coordinate of this tile is given by `blockIdx.y`.
* The `k`-coordinate of this tile is unspecified -- we want all of the tiles in `K` so the coordinate is `_`, the `Underscore` value, to keep that mode.
Then, `local_tile` is used to remove the modes of the tiler and coord corresponding to the `X`s. That is, the `Step<_1, X,_1>` is just shorthand for
```cpp
// Use select<0,2> to use only the M- and K-modes of the tiler and coord
Tensor gA = local_tile(mA, select<0,2>(cta_tiler), select<0,2>(cta_coord));
```
This `local_tile` is simply shorthand for
1. apply the tiler via [`zipped_divide`](./02_layout_algebra.md#zipped-tiled-flat-divides)
```cpp
// ((BLK_M,BLK_K),(m,k))
Tensor gA_mk = zipped_divide(mA, select<0,2>(cta_tiler));
```
2. apply the coord to the second mode, the "Rest" mode, to extract out the correct tiles for this CTA.
```cpp
// (BLK_M,BLK_K,k)
Tensor gA = gA_mk(make_coord(_,_), select<0,2>(cta_coord));
```
Because the projections of the tiler and coord are symmetric and the two steps (apply a tiler and then slice into the rest-mode to produce a partition) are so common, they are wrapped together into the projective `local_tile` interface.
For tensor `A`, we are left with a rank-3 tensor of shape `(BLK_M,BLK_K,k)`. The first two modes are precisely the modes of the CTA tile and the last mode indexes over all of the tiles that will be reduced by this CTA. In the mainloop section below, this mode is iterated over via the `k_tile` loop.
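As a rough mental model of the tile-then-slice steps, the offset arithmetic for `gA` can be sketched in plain C++. `TileView` and `model_local_tile_A` are hypothetical names, not CuTe API, and the sketch assumes a flat `(M,K)` matrix with strides `(dM,dK)`.

```cpp
#include <cassert>

// Plain C++ model of the tile-then-slice idea (illustrative; not the CuTe API).
// mA is an (M,K) matrix with strides (dM,dK), tiled by (BLK_M,BLK_K).
struct TileView {
  int base;     // offset of the tile's (0,0) element within mA
  int dM, dK;   // strides inside the tile (unchanged from mA)
  int dTileK;   // offset step between consecutive k-tiles

  // Offset of element (i,j) of the k_tile-th tile: shape (BLK_M,BLK_K,k)
  int operator()(int i, int j, int k_tile) const {
    return base + i * dM + j * dK + k_tile * dTileK;
  }
};

// Fix the m-tile and keep every k-tile (the `_` in the CTA coordinate).
TileView model_local_tile_A(int dM, int dK, int BLK_M, int BLK_K, int m_tile) {
  return TileView{m_tile * BLK_M * dM, dM, dK, BLK_K * dK};
}
```

The point of the model is that tiling changes neither the data nor the strides inside a tile; it only adds a base offset for the chosen m-tile and a new mode that steps from one k-tile to the next.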
### SMEM tensors
The shared memory layouts that are used to hold the tiles of data for A and B are also passed in as the parameters `ASmemLayout sA_layout` and `BSmemLayout sB_layout`.
These are defined in `gemm_nt` as
```c++
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
```
which produces simple M-major and N-major layouts. In `gemm_tn` these are
```cpp
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM,bK), LayoutRight{}); // (m,k) -> smem_idx; k-major
auto sB = make_layout(make_shape(bN,bK), LayoutRight{}); // (n,k) -> smem_idx; k-major
```
which produces simple K-major layouts.
As is evident, these smem layouts can be almost anything. Inside the kernel, they are checked for only two properties: the shared memory layouts are static and they are the same top-level shape as the `CtaTiler`.
```cpp
// Preconditions
static_assert(is_static<ASmemLayout>::value);
static_assert(is_static<BSmemLayout>::value);
static_assert(is_static<CSmemLayout>::value);
CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M
CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N
CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K
CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K
```
Use of static layouts has a few advantages.
* Static layouts let us statically allocate shared memory as shown below.
* Static layouts are often more efficient and allow CuTe to dispatch to optimized implementations.
* Static layouts make it easier to prove correctness of the algorithm and provide checks like the above -- the smem layout sizes are the same as the CTA tile sizes.
As stated, the shared memory layouts can be anything that satisfies those conditions. Optimizing kernels like these is often performed by finding a good shared memory layout that provides good access patterns for both the writes to and the reads from shared memory. This includes the ability to vectorize reads and writes as well as avoid shared memory bank conflicts.
With the static smem layouts, the `gemm_device` kernel can allocate the required shared memory and create the smem `Tensor`s.
```cpp
// Shared memory buffers
__shared__ TA smemA[cosize_v<ASmemLayout>];
__shared__ TB smemB[cosize_v<BSmemLayout>];
Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K)
Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K)
```
Note how the shared memory allocation depends only on the data type and the layout. What's a `cosize`? Because a `Layout` is a function, we can speak of its domain and codomain. The `size` of a layout is the size of its domain and the `cosize` of a layout is the size of its codomain. If we want to allocate an array for which all the offsets produced by a layout are valid, then we can use the `cosize` of the layout as the length of the array (in units of elements).
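The difference between `size` and `cosize` is easiest to see with a padded layout. The following is a small plain C++ model with hypothetical names (`FlatLayout`, `layout_size`, `layout_cosize`); it assumes flat layouts with non-negative strides and is not how CuTe implements these functions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative model of a flat layout: one shape and one stride per mode.
struct FlatLayout {
  std::vector<int> shape;
  std::vector<int> stride;
};

// Size of the domain: the number of coordinates the layout accepts.
int layout_size(const FlatLayout& l) {
  int s = 1;
  for (int e : l.shape) s *= e;
  return s;
}

// Size of the codomain, assuming non-negative strides and a non-empty layout:
// one past the largest offset the layout can produce.
int layout_cosize(const FlatLayout& l) {
  int last = 0;
  for (std::size_t i = 0; i < l.shape.size(); ++i) {
    last += (l.shape[i] - 1) * l.stride[i];
  }
  return last + 1;
}
```

For the compact layout `(128,8):(1,128)` both values are 1024, while for a padded layout such as `(128,8):(1,129)` the size is still 1024 but the cosize grows to 1031, so the backing array must be longer than the number of logical elements.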
### Copy partitioning
The kernel now has tiles of global memory by applying the `CtaTiler` to the full tensors and it also has tiles of shared memory by allocating appropriately. We now want to create an efficient way to copy one tile of global memory to our tile of shared memory. A trivial way to do this would be to use a single thread and copy each element.
```cpp
if (thread0()) {
Tensor gA0 = gA(_,_,0); // (BLK_M,BLK_K), the 0th tile
for (int i = 0; i < size(sA); ++i) {
sA(i) = gA0(i);
}
}
```
This would work, but we have lots of threads to use inside this CTA, so let's use them!
If we partition the two tiles of data across the threads in the CTA, then each thread can copy its own subtensor of data. There are lots of ways this partitioning could occur, however.
The `gemm_nt` function defines two layouts of *threads* as
```c++
// Define thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{},Int<8>{})); // (m,k) -> thr_idx
auto tB = make_layout(make_shape(Int<32>{},Int<8>{})); // (n,k) -> thr_idx
```
and the `gemm_tn` function defines two layouts of *threads* as
```c++
// Define thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{},Int<8>{}), LayoutRight{}); // (m,k) -> thr_idx; k-major
auto tB = make_layout(make_shape(Int<32>{},Int<8>{}), LayoutRight{}); // (n,k) -> thr_idx; k-major
```
Both cases happen to use 32x8 threads, which will be used to partition a 128x8 tile of gmem and smem data into a 4x1 subtensor for each thread. The only difference here is that `gemm_nt` uses M-major and N-major threads to match the order of data in global memory and `gemm_tn` uses K-major threads to match the order of data in global memory.
Again, the conditions on the thread layouts are checked inside the kernel.
```cpp
static_assert(is_static<AThreadLayout>::value);
static_assert(is_static<BThreadLayout>::value);
CUTE_STATIC_ASSERT_V(size(tA) == size(tB)); // NumThreads
CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{}); // BLK_M / THR_M
CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{}); // BLK_K / THR_K
CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{}); // BLK_N / THR_N
CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{}); // BLK_K / THR_K
```
These thread layouts are then used to partition the global memory tensors and the shared memory tensors
```cpp
Tensor tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k)
Tensor tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K)
Tensor tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k)
Tensor tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K)
CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA)); // THR_M
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // THR_K
CUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB)); // THR_N
CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // THR_K
```
where `local_partition` is a lot like `local_tile`, except the coordinate slices into the tile-mode (the first mode) of the `zipped_divide` rather than the rest-mode (the second mode). That is, each thread gets one element of data assigned to it per thread tile and that thread tile is repeated to cover the entire data tile.
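To see which elements a given thread ends up owning, the partitioning can be modeled in plain C++. `model_local_partition` is a hypothetical helper, not the CuTe function, and it assumes a column-major thread layout like the one `gemm_nt` uses (`thr_idx = tm + THR_M * tk`).

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Plain C++ model (illustrative; not the CuTe API) of partitioning a
// (BLK_M,BLK_K) tile with a column-major (THR_M,THR_K) thread layout.
// Returns the (m,k) tile coordinates owned by thread thr_idx.
std::vector<std::pair<int, int>>
model_local_partition(int BLK_M, int BLK_K, int THR_M, int THR_K, int thr_idx) {
  // Invert thr_idx = tm + THR_M * tk to recover the thread's coordinate.
  int tm = thr_idx % THR_M;
  int tk = thr_idx / THR_M;
  std::vector<std::pair<int, int>> owned;
  // The thread tile repeats (BLK_M/THR_M) x (BLK_K/THR_K) times over the data tile,
  // and the thread takes the same position in every repetition.
  for (int j = 0; j < BLK_K / THR_K; ++j) {
    for (int i = 0; i < BLK_M / THR_M; ++i) {
      owned.push_back({tm + i * THR_M, tk + j * THR_K});
    }
  }
  return owned;
}
```

With a 128x8 tile and 32x8 threads, thread 33 has coordinate `(1,1)` and owns the four elements `(1,1)`, `(33,1)`, `(65,1)`, and `(97,1)`, which matches the 4x1 subtensor per thread described above.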
The naming convention `tAsA` is pretty typical across CuTe and CUTLASS. This is read as "Partitioning pattern `tA` applied to tensor `sA`". In the next section, we'll see a different partitioner applied to `sA` to produce `tCsA`. By applying the same partitioning pattern, `tA`, to tensors `sA` and `gA`, we preserve the *logical consistency* of those tensors (checked by the assertions above) where logical elements between the two tensors correspond despite any differences in their data layouts. When used in `cute::copy`, for example, this naming convention lets us lexically verify that the two tensors are using the same partitioning pattern.
With the data partitioned across the threads, *every thread* can now participate in the copy by writing
```cpp
copy(tAgA(_,_,0), tAsA);
```
because every thread owns a different subtensor of the tile that will be copied.
### Math partitioning
The kernel now has tiles of shared memory copied in from global memory. We now want to create an efficient way to compute and accumulate the matrix product on that tile of shared memory. A trivial way to do this would be to use a single thread and compute directly.
```cpp
if (thread0()) {
for (int m = 0; m < size<0>(gC); ++m) {
for (int n = 0; n < size<1>(gC); ++n) {
for (int k = 0; k < size<1>(sA); ++k) {
gC(m,n) += sA(m,k) * sB(n,k);
}
}
}
}
```
This would work, but we have lots of threads to use inside this CTA, so let's use them!
If we partition the output tile `gC` across the threads in the CTA, then each thread can compute its own subtensor. There are lots of ways this partitioning could occur, however.
The `gemm_nt` and `gemm_tn` functions define one more layout of *threads*:
```cpp
// Define thread layouts (static)
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx; m-major
```
This is an M-major 16x16 layout of threads which will be used to partition a 128x128 tile of `C`-data, resulting in each thread computing its own 8x8 subtensor of `gC`.
Again, the conditions on the thread layouts are checked inside the kernel.
```cpp
static_assert(is_static<CThreadLayout>::value);
CUTE_STATIC_ASSERT_V(size(tC) == size(tA)); // NumThreads
CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{}); // BLK_M / THR_M
CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{}); // BLK_N / THR_N
```
These thread layouts are then used to partition the tiles of data in global memory and shared memory
```cpp
// Partition sA (M,K) by the rows of tC
Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K)
// Partition sB (N,K) by the cols of tC
Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K)
// Partition gC (M,N) by the tile of tC
Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N)
// Allocate the accumulators -- same shape/layout as the partitioned data
Tensor tCrC = make_tensor_like(tCgC); // (THR_M,THR_N)
CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC)); // THR_M
CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA)); // THR_M
CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC)); // THR_N
CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB)); // THR_N
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB)); // BLK_K
```
where we've used the same projection-style interface to avoid applying the `N`-mode of `tC` to the `(BLK_M,BLK_K)` shape of `sA` and avoid applying the `M`-mode of `tC` to the `(BLK_N,BLK_K)` shape of `sB`.
<p align="center">
<img src="../../images/cute/tC_partitioning.png" alt="tC_partitioning.png" height="300"/>
</p>
This diagram shows a `tC` layout, highlights two threads in green and blue, shows the projections of the `tC` layout, and finally highlights the subtensors within `sA`, `sB`, and `gC` that `tCsA`, `tCsB`, and `tCgC` represent.
With the data partitioned across the threads, *every thread* can now participate in the compute step by writing
```cpp
gemm(tCsA, tCsB, tCrC);
```
because every thread owns different subtensors of the data to be computed.
### Mainloop
The mainloop iterates over tiles of global memory, reads those tiles into shared memory, and then performs the matrix-multiply and accumulates into the accumulators.
```c++
// TUTORIAL: Example of a very simple compute mainloop
// copy(.) operates on the global and shared memory via the tA|tB partitioning
// gemm(.) operates on the shared and register memory via the tC partitioning
auto K_TILE_MAX = size<2>(tAgA);
for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
{
// Copy gmem to smem with tA|tB thread-partitioned tensors
copy(tAgA(_,_,k_tile), tAsA); // A (THR_M,THR_K) -> (THR_M,THR_K)
copy(tBgB(_,_,k_tile), tBsB); // B (THR_N,THR_K) -> (THR_N,THR_K)
cp_async_fence(); // Label the end of (potential) cp.async instructions
cp_async_wait<0>(); // Sync on all (potential) cp.async instructions
__syncthreads(); // Wait for all threads to write to smem
// Compute gemm on tC thread-partitioned smem
gemm(tCsA, tCsB, tCrC); // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)
__syncthreads(); // Wait for all threads to read from smem
}
```
We can see that `k_tile` iterates over each tile of data, the `cute::copy` is performed for the current `k_tile` using the `tA` and `tB` thread-partitioned tensors, and the `cute::gemm` is computed for that current `k_tile` using the `tC` thread-partitioned tensors. Synchronization is provided so that this kernel works on any architecture.
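The structure of the mainloop can also be emulated on the host for a single CTA tile. The sketch below is plain C++, with a hypothetical `emulate_cta` function standing in for one CTA; it assumes M-major `A`, N-major `B`, M-major `C` with compact leading dimensions, tile sizes that divide the problem, and `alpha = 1`, `beta = 0`. It ignores the per-thread partitioning and synchronization and only mirrors the copy-then-accumulate order.

```cpp
#include <cassert>
#include <vector>

// Host-side emulation of the mainloop structure for one CTA tile
// (illustrative; not the kernel from sgemm_1.cu).
void emulate_cta(int M, int N, int K, int BLK_M, int BLK_N, int BLK_K,
                 int m_tile, int n_tile,
                 const std::vector<float>& A,   // (M,K):(1,M)
                 const std::vector<float>& B,   // (N,K):(1,N)
                 std::vector<float>& C)         // (M,N):(1,M)
{
  std::vector<float> sA(BLK_M * BLK_K), sB(BLK_N * BLK_K);  // stand-in for smem
  std::vector<float> rC(BLK_M * BLK_N, 0.0f);               // accumulators
  for (int k_tile = 0; k_tile < K / BLK_K; ++k_tile) {
    // "copy": stage the k_tile-th tiles of A and B
    for (int k = 0; k < BLK_K; ++k) {
      for (int m = 0; m < BLK_M; ++m)
        sA[m + k * BLK_M] = A[(m_tile * BLK_M + m) + (k_tile * BLK_K + k) * M];
      for (int n = 0; n < BLK_N; ++n)
        sB[n + k * BLK_N] = B[(n_tile * BLK_N + n) + (k_tile * BLK_K + k) * N];
    }
    // "gemm": accumulate the staged tiles into the accumulators
    for (int n = 0; n < BLK_N; ++n)
      for (int m = 0; m < BLK_M; ++m)
        for (int k = 0; k < BLK_K; ++k)
          rC[m + n * BLK_M] += sA[m + k * BLK_M] * sB[n + k * BLK_N];
  }
  // Write the accumulators to this CTA's tile of C
  for (int n = 0; n < BLK_N; ++n)
    for (int m = 0; m < BLK_M; ++m)
      C[(m_tile * BLK_M + m) + (n_tile * BLK_N + n) * M] = rC[m + n * BLK_M];
}
```

Running `emulate_cta` once per `(m_tile, n_tile)` pair plays the role of the grid launch, with each call covering one `(BLK_M,BLK_N)` tile of `C`.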
## `sgemm_2.cu`
An example that uses more complex `TiledMMA` and `TiledCopy` to perform partitioning in place of the `tA`, `tB`, and `tC` thread layouts. With this example, we try to emphasize that the shared memory layouts, the partitioning patterns, and the PTX instruction to use in each stage can be specified independently.
### TiledCopy
First, we can replace the `tA` partitioning and `tB` partitioning with `TiledCopy` partitioning, which provides for more complex partitioning patterns and checked dispatch to specific copy instructions.
As a first example, let's look at the `TiledCopy` that `gemm_nt` generates.
```cpp
TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TA>{}, // Atom: Copy TAs as if they were uint128_t
Layout<Shape<_32,_8>>{}, // Thr layout 32x8 m-major
Layout<Shape< _4,_1>>{}); // Val layout 4x1 m-major
print_latex(copyA);
```
The easiest way to see what this `TiledCopy` does is to look at the partition pattern in LaTeX.
<p align="center">
<img src="../../images/cute/TiledCopyA.png" alt="TiledCopyA.png" height="300"/>
</p>
On the left is the source-tensor partitioning and on the right is the destination-tensor partitioning. The partition patterns are the same for this case, but there exist PTX instructions which require different patterns in the source and destination. The diagram shows that each thread reads 4x1 `TA` elements and there are 32x8 threads. The `UniversalCopy<uint128_t>` forces the instruction to use a 128-bit copy instruction. If the partition (of `sA` or `gA` in this case) does not result in 4 `TA` elements that can be vectorized to a 128-bit load/store, then CuTe will statically fail with an error message to that effect.
To use the `TiledCopy`, the kernel writes
```cpp
ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k)
Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K)
// Allocate registers same shape/layout as partitioned data
Tensor tArA = make_fragment_like(tAsA); // (CPY,CPY_M,CPY_K)
```
which applies the source-tensor partitioning to `gA` via `partition_S` and applies the destination-tensor partitioning to `sA` via `partition_D`. The first mode, `CPY`, of the result tensors holds all of the elements that a single instruction will consume. In this case, that mode should have size-4 since there are four `TA=float` elements in a single 128-bit `uint128_t`.
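The size of the `CPY` mode follows from simple arithmetic on the access width. The helpers below are a plain C++ sketch with hypothetical names; they are not CuTe's implementation, and the contiguity and alignment conditions are the usual requirements for a vector load/store, stated here as an assumption about what the static check has to establish.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Illustrative helpers (not CuTe's implementation).
// Number of elements of elem_bytes bytes packed into one access of access_bits bits.
constexpr std::size_t elems_per_access(std::size_t access_bits, std::size_t elem_bytes) {
  return access_bits / (8 * elem_bytes);
}

// A run of elements can be moved by one vector access only if the elements are
// contiguous (stride 1) and the first byte address is aligned to the access width.
bool can_use_vector_access(std::uintptr_t byte_addr, int elem_stride,
                           std::size_t access_bits) {
  return elem_stride == 1 && byte_addr % (access_bits / 8) == 0;
}
```

For 4-byte `float` elements and a 128-bit access this gives four elements per instruction, which is why the M-major value layout `4x1` pairs naturally with M-major data.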
Once the partition has been performed, we can execute the `copy` on the thread-partitioned tensors using the provided instruction in `copy_a`.
```cpp
cute::copy(copy_a, tAgA, tArA);
```
### TiledMMA
Next, we can replace the `tC` partitioning with `TiledMMA` partitioning, which provides for more complex partitioning patterns and checked dispatch to specific MMA instructions.
As a first example, let's look at the `TiledMMA` that `gemm_nt` generates.
```cpp
TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
Layout<Shape<_16,_16,_1>>{}); // 16x16x1 UniversalFMA
print_latex(mmaC);
```
The easiest way to see what this `TiledMMA` does is to look at the partition pattern in LaTeX.
<p align="center">
<img src="../../images/cute/TiledMmaC.png" alt="TiledMmaC.png" height="300"/>
</p>
On the left is the A-tensor partitioning, on the top is the B-tensor partitioning, and in the middle is the C-tensor partitioning. Because the `UniversalFMA` is a 1x1x1 MMA instruction, a 16x16x1 tiling of them results in a 16x16x1 `TiledMMA`. Other MMA instructions will have different threads involved and have different instruction sizes. In this case, all threads will read a single element from `A`, `B`, and `C` each.
To use the `TiledMMA`, the kernel writes
```cpp
ThrMMA thr_mma = mma.get_slice(threadIdx.x);
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N)
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)
```
which applies the A-tensor partitioning to `sA` via `partition_A`, applies the B-tensor partitioning to `sB` via `partition_B`, and applies the C-tensor partitioning to `gC` via `partition_C`. The first mode, `MMA`, of the result tensors holds all of the elements that a single instruction will consume. In this case, that mode should have size-1 since `UniversalFMA` is a 1x1x1 MMA, but in general the size of the first mode can vary and may not even be the same across `tCsA`, `tCsB`, and `tCgC`, depending on the MMA.
Once the partition has been performed, we can execute the `gemm` on the thread-partitioned tensors using the provided instruction in `mma`.
```cpp
cute::gemm(mma, tCsA, tCsB, tCrC);
```
### Other changes
In this version, we have also updated the shared memory layouts for `gemm_tn` from K-major to
```cpp
// Define the smem layouts (static)
auto sA = make_layout(make_shape ( bM, bK),
make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
auto sB = make_layout(make_shape ( bN, bK),
make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
```
which produces M-major and N-major layouts, but they are padded to avoid shared memory bank conflicts. This simply improves the access pattern to and from shared memory and no other changes in the kernel are required.
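The effect of the one-element padding can be illustrated with a small bank model in plain C++. The helpers `max_bank_conflict` and `walk_k` are hypothetical, and the model assumes 32 shared memory banks of 4-byte words with 4-byte elements, which is an assumption about the hardware rather than something `sgemm_2.cu` states. The access pattern shown (several threads reading along `k` at a fixed `m`) is only one illustrative pattern, not necessarily the one the kernel performs.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Illustrative model: 32 banks of 4-byte words, 4-byte (float) elements.
// Returns the largest number of the given offsets that map to the same bank.
// Assumes the offsets are distinct and non-negative.
int max_bank_conflict(const std::vector<int>& elem_offsets) {
  std::vector<int> per_bank(32, 0);
  for (int off : elem_offsets) per_bank[off % 32] += 1;
  return *std::max_element(per_bank.begin(), per_bank.end());
}

// Offsets touched when BLK_K threads read (m, 0..BLK_K-1) from an
// M-major tile whose K-mode stride is stride_k.
std::vector<int> walk_k(int m, int BLK_K, int stride_k) {
  std::vector<int> offs;
  for (int k = 0; k < BLK_K; ++k) offs.push_back(m + k * stride_k);
  return offs;
}
```

Under this model, a K-mode stride of 128 sends all eight accesses to the same bank, while a stride of 129 spreads them over eight different banks.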
## `sgemm_sm70.cu`
An example that uses an optimized mainloop for Volta SM70 architectures that pipelines shared memory and register memory.
## `sgemm_sm80.cu`
An example that uses an optimized mainloop for Ampere SM80 architectures that explicitly pipelines shared memory using asynchronous reads from global memory.
## Next steps
All of the above examples assume that the CTA tile size divides the problem size so that global memory loads do not need to be predicated. The
[predication section of the tutorial](./0y_predication.md)
explains what to do if a matrix tiling
doesn't perfectly divide the matrix.
## GETT as GEMM
"GETT" here stands for "general(ized) tensor times tensor," a tensor contraction.
CuTe permits matrices to have nested `Layout`s.
This means that we can fold a `Tensor` into a "matrix" by grouping modes according to their categories.
As a result, we can implement GETT by using
our existing GEMM implementation. Included below is a launcher like `gemm_nt` that uses the same device kernel contained in `sgemm_1.cu` to compute a GETT with two m-modes.
```cpp
// Setup params for a GETT with two m-modes.
// The A and C tensors are assumed to be m0-major.
// Calls sgemm_1.cu's gemm_device<<<>>> without modification.
template <class TA, class TB, class TC,
class Alpha, class Beta>
void
gett(int m0, int m1, int n, int k,
Alpha alpha,
TA const* A, int ldAm1, int ldAk, // m0-major
TB const* B, int ldBk,
Beta beta,
TC * C, int ldCm1, int ldCn, // m0-major
cudaStream_t stream = 0)
{
using namespace cute;
// Define shapes (dynamic)
auto M = make_shape(m0, m1); // (m0,m1)-multimode M
auto N = int(n);
auto K = int(k);
auto prob_shape = make_shape(M, N, K); // (M, N, K)
// Define NT strides (mixed)
auto dA = make_stride(make_stride(Int<1>{}, ldAm1), ldAk); // (dM, dK)
auto dB = make_stride(Int<1>{}, ldBk); // (dN, dK)
auto dC = make_stride(make_stride(Int<1>{}, ldCm1), ldCn); // (dM, dN)
// Define CTA tile sizes (static)
auto bM = Shape<_64, _2>{}; // Take _64 elements from m0 and _2 elements from m1
auto bN = Int<128>{};
auto bK = Int< 8>{};
auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K)
// Define the smem layouts (static)
auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major
auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major
// Define the thread layouts (static)
auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (m,k) -> thr_idx
auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (n,k) -> thr_idx
auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx
dim3 dimBlock(size(tC));
dim3 dimGrid(size(ceil_div(M, bM)),
size(ceil_div(N, bN)));
gemm_device<<<dimGrid, dimBlock, 0, stream>>>
(prob_shape, cta_tiler,
A, dA, sA, tA,
B, dB, sB, tB,
C, dC, sC, tC,
alpha, beta);
}
```
Note that the only changes are the definition of shape `M`, the definition of strides `dA` and `dC`, and the definition of the CTA Tiler `bM`. The above uses a multimodal problem shape `M = (m0,m1)` and a multimodal CTA Tiler `bM = <_64,_2>` to change which portion of the global memory tensors `A` and `C` each CTA will be responsible for computing.
Similar examples can be found for CUTLASS 3.x kernels that are based on CuTe, such as [this Hopper GETT example](../../../examples/51_hopper_gett).
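To illustrate how a multimodal `M = (m0,m1)` folds into a single matrix mode, here is a plain C++ model of the nested `A` layout used by the `gett` launcher above. `GettALayout` is a hypothetical name, not a CuTe type; the sketch assumes the folded index runs over `m0` fastest.

```cpp
#include <cassert>

// Plain C++ model (illustrative; not the CuTe API) of the nested A layout
// ((m0,m1),k) : ((1,ldAm1),ldAk).
struct GettALayout {
  int m0, m1, k;
  int ldAm1, ldAk;

  // Offset from the natural coordinate ((i0,i1),kk).
  int operator()(int i0, int i1, int kk) const {
    return i0 * 1 + i1 * ldAm1 + kk * ldAk;
  }

  // Offset from a "matrix" coordinate (i,kk), where i indexes the folded
  // M-mode of size m0*m1 with i0 varying fastest.
  int operator()(int i, int kk) const {
    return (*this)(i % m0, i / m0, kk);
  }
};
```

Both coordinate forms address the same element, which is why a kernel written against an `(M,K)` matrix can run unchanged on a tensor with two m-modes.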