{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6acbea5d",
   "metadata": {},
   "source": [
    "# Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension\n",
    "This notebook walks through a basic example of using the CUTLASS Python interface to declare\n",
    "a grouped GEMM kernel and export it as a PyTorch CUDA extension. Note that GEMM and Conv2d kernels can also be exported as PyTorch CUDA extensions.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/02_pytorch_extension_grouped_gemm.ipynb)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d70560e",
   "metadata": {},
   "source": [
    "## Prerequisites for running on Colab\n",
    "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm that a GPU is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc7c7458",
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2107bb0d",
   "metadata": {},
   "source": [
    "If running on Colab, you will need to install the CUTLASS Python interface and PyTorch. To do so, uncomment the following line and run the cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9852cb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "!#pip install nvidia-cutlass torch --extra-index-url https://download.pytorch.org/whl/cu121"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "962324fd",
   "metadata": {},
   "source": [
    "## Background on grouped GEMM\n",
    "Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides)\n",
    "in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM,\n",
    "without the requirement that the sizes and strides of each GEMM be the same.\n",
    "\n",
    "For example, if one has `p` GEMMs with sizes:\n",
    "```text\n",
    "M_1 x N_1 x K_1\n",
    "M_2 x N_2 x K_2\n",
    "...\n",
    "M_p x N_p x K_p\n",
    "```\n",
    "CUTLASS's grouped GEMM will execute all `p` of these GEMMs in a single CUDA kernel.\n",
    "\n",
    "Grouped GEMM is particularly beneficial for many small problems, each of which would underutilize the device if run in isolation: grouping them lets a single launch saturate the GPU.\n",
    "\n",
    "## Declaring a grouped GEMM via the CUTLASS Python interface\n",
    "A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one\n",
    "simply calls `cutlass.op.GroupedGemm`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdcf21d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cutlass\n",
    "import torch\n",
    "\n",
    "dtype = torch.float16\n",
    "plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)"
   ]
  },
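  {
   "cell_type": "markdown",
   "id": "ref-semantics-md",
   "metadata": {},
   "source": [
    "Before compiling and running the kernel, it is worth spelling out the semantics a grouped GEMM must reproduce. The sketch below is for illustration only (the helper `grouped_gemm_reference` is not part of the CUTLASS API); CUTLASS fuses all of the groups into a single kernel launch rather than looping. The comparisons later in this notebook use exactly this list-comprehension form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ref-semantics-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference semantics of a grouped GEMM: one independent matmul per group,\n",
    "# where sizes may differ across groups. This Python loop launches one CUDA\n",
    "# kernel per GEMM, which is precisely the overhead a grouped GEMM avoids.\n",
    "def grouped_gemm_reference(As, Bs):\n",
    "    return [a @ b for a, b in zip(As, Bs)]"
   ]
  },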
  {
   "cell_type": "markdown",
   "id": "514f40a4",
   "metadata": {},
   "source": [
    "We can then compile and run this operation on a group of GEMMs. We'll first set up some utility functions to initialize GEMMs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2a7371e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "random.seed(2023)\n",
    "\n",
    "# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K\n",
    "def initialize(dtype, M, N, K):\n",
    "    sizes = [(M, K), (K, N), (M, N), (M, N)]\n",
    "    return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]\n",
    "\n",
    "# Utility function to generate `problems` GEMMs of random sizes\n",
    "def generate_problems(problems):\n",
    "    valid_sizes = [128, 256, 512, 1024]\n",
    "    As, Bs, Cs, Ds = [], [], [], []\n",
    "    for _ in range(problems):\n",
    "        M, N, K = [random.choice(valid_sizes) for _ in range(3)]\n",
    "        A, B, C, D = initialize(dtype, M, N, K)\n",
    "        As.append(A)\n",
    "        Bs.append(B)\n",
    "        Cs.append(C)\n",
    "        Ds.append(D)\n",
    "    return As, Bs, Cs, Ds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "590a3bc5",
   "metadata": {},
   "source": [
    "We'll next run a group of 20 GEMMs via the CUTLASS Python interface and via PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "776c9233",
   "metadata": {},
   "outputs": [],
   "source": [
    "As, Bs, Cs, Ds = generate_problems(20)\n",
    "\n",
    "plan.run(As, Bs, Cs, Ds, print_module=True)\n",
    "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
    "\n",
    "for d, d_torch in zip(Ds, Ds_torch):\n",
    "    assert torch.allclose(d, d_torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "766e4f03",
   "metadata": {},
   "source": [
    "## Exporting the CUTLASS kernel to a PyTorch CUDA extension\n",
    "The procedure above allows one to quickly experiment with using a CUTLASS kernel. However, one might prefer to use the CUTLASS kernel via a [PyTorch CUDA extension](https://pytorch.org/tutorials/advanced/cpp_extension.html). This avoids any runtime overhead associated with the Python portions of the CUTLASS Python interface.\n",
    "\n",
    "The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for later \"ahead-of-time\" compilation, or be just-in-time compiled and returned to the user.\n",
    "\n",
    "To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a98dee6",
   "metadata": {},
   "outputs": [],
   "source": [
    "op = plan.construct()\n",
    "grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)"
   ]
  },
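  {
   "cell_type": "markdown",
   "id": "aot-emit-md",
   "metadata": {},
   "source": [
    "As an aside, the same call can be used to emit the extension for ahead-of-time compilation instead. This is a minimal sketch, assuming that `jit=False` writes the extension sources to `sourcedir` without compiling or loading them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aot-emit-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of ahead-of-time emission (assumes jit=False only writes out the\n",
    "# sources). Uncomment to emit the extension for a later manual build:\n",
    "#cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=False)"
   ]
  },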
  {
   "cell_type": "markdown",
   "id": "c8ca3991",
   "metadata": {},
   "source": [
    "The `cutlass.emit.pytorch` function emits:\n",
    "* `out/grouped_gemm_kernel.cu`: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors\n",
    "* `out/grouped_gemm.cpp`: This file contains a C++ wrapper around the aforementioned CUTLASS kernel\n",
    "* `out/setup.py`: This file contains the `setuptools` script for building and installing the generated extension\n",
    "\n",
    "The extension can be built by running:\n",
    "```bash\n",
    "cd out\n",
    "TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py install\n",
    "```\n",
    "where `TORCH_CUDA_ARCH_LIST` is set to the compute capability of the device on which the kernel will be run (here, \"8.0\" for SM80).\n",
    "\n",
    "See the PyTorch [\"Custom C++ and CUDA Extensions\"](https://pytorch.org/tutorials/advanced/cpp_extension.html) tutorial for more details.\n",
    "\n",
    "One could then use the kernel in a later PyTorch module by running:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import grouped_gemm\n",
    "\n",
    "grouped_gemm.run(As, Bs)\n",
    "```\n",
    "\n",
    "In this case, however, we set `jit=True`, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly.\n",
    "Under the hood, this leverages the [torch.utils.cpp_extension.load](https://pytorch.org/tutorials/advanced/cpp_extension.html) method\n",
    "and returns the loaded extension.\n",
    "\n",
    "We can then use the extension and compare its results to running the GEMMs via vanilla PyTorch GEMMs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cecb26a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "Ds = grouped_gemm.run(As, Bs)\n",
    "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
    "for d, d_torch in zip(Ds, Ds_torch):\n",
    "    assert torch.allclose(d, d_torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50db80e4",
   "metadata": {},
   "source": [
    "Finally, we can profile our grouped GEMM extension:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b76805d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_warmup = 20\n",
    "num_profile = 100\n",
    "\n",
    "# Warmup iterations\n",
    "for _ in range(num_warmup):\n",
    "    Ds = grouped_gemm.run(As, Bs)\n",
    "    Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
    "    torch.cuda.synchronize()\n",
    "\n",
    "# Timing iterations\n",
    "import time\n",
    "grouped = 0\n",
    "nongrouped = 0\n",
    "for _ in range(num_profile):\n",
    "    start = time.time()\n",
    "    Ds = grouped_gemm.run(As, Bs)\n",
    "    torch.cuda.synchronize()\n",
    "    grouped += time.time() - start\n",
    "\n",
    "    start = time.time()\n",
    "    Ds_torch = [a @ b for a, b in zip(As, Bs)]\n",
    "    torch.cuda.synchronize()\n",
    "    nongrouped += time.time() - start\n",
    "\n",
    "print('Grouped: {:.3f} us'.format(grouped * 1e6 / num_profile))\n",
    "print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6 / num_profile))\n",
    "print('Speedup: {:.3f}'.format(nongrouped / grouped))"
   ]
  },
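  {
   "cell_type": "markdown",
   "id": "cuda-event-timing-md",
   "metadata": {},
   "source": [
    "Host-side `time.time()` plus `torch.cuda.synchronize()` measures wall-clock time, which includes Python and kernel-launch overhead. As an alternative, the device work itself can be timed with CUDA events. This is a minimal sketch using the standard `torch.cuda.Event` API; it times only the grouped kernel:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cuda-event-timing-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time the grouped kernel on-device with CUDA events.\n",
    "start_event = torch.cuda.Event(enable_timing=True)\n",
    "end_event = torch.cuda.Event(enable_timing=True)\n",
    "\n",
    "start_event.record()\n",
    "for _ in range(num_profile):\n",
    "    grouped_gemm.run(As, Bs)\n",
    "end_event.record()\n",
    "torch.cuda.synchronize()\n",
    "\n",
    "# elapsed_time returns milliseconds; report the average in microseconds.\n",
    "print('Grouped (CUDA events): {:.3f} us'.format(start_event.elapsed_time(end_event) * 1e3 / num_profile))"
   ]
  }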
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}