{
"cells": [
{
"cell_type": "markdown",
"id": "1ef96b3f",
"metadata": {},
"source": [
"# Basic example of using the CUTLASS Python interface\n",
"This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs.\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n"
]
},
{
"cell_type": "markdown",
"id": "962324fd",
"metadata": {},
"source": [
"We first import various packages needed for the example and construct the input and output tensors that will be used in our example.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0e324219",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:39.749457Z",
"iopub.status.busy": "2023-04-18T17:59:39.748884Z",
"iopub.status.idle": "2023-04-18T17:59:43.907956Z",
"shell.execute_reply": "2023-04-18T17:59:43.907069Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import numpy as np\n",
"import random\n",
"\n",
"import cutlass\n",
"\n",
"# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n",
"# omit this information.\n",
"print_module = True\n",
"\n",
"m = 128\n",
"n = m\n",
"k = m\n",
"\n",
"dtype = np.float16\n",
"type_A = np.float16\n",
"type_B = np.float16\n",
"type_C = np.float16\n",
"type_D = np.float16\n",
"\n",
"np.random.seed(1234)\n",
"random.seed(1234)\n",
"scope_min = -4\n",
"scope_max = 4\n",
"tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n",
"tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n",
"tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n",
"\n",
"alpha = np.float16(1.)\n",
"beta = np.float16(0.)\n",
"\n",
"tensor_D = np.zeros(tensor_C.shape).astype(type_D)"
]
},
{
"cell_type": "markdown",
"id": "f2c7bf48",
"metadata": {},
"source": [
"## Declaring and running a GEMM\n",
"To get started, one only needs to provide the tensors declared above to the `cutlass.op.Gemm` call.\n",
"This sets up a default GEMM operation for the given device on which you are running.\n",
"\n",
"Assuming that we are running on SM80, this default to using a GEMM that leverages FP16 Tensor Core operations.\n",
"\n",
"Calling `plan.run()` will generate the CUTLASS C++ kernel in question, compile it, and run it on the tensors we previously passed in. By setting `print_module` to `true`, the C++ code that is emitted is printed."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0dfd8975",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:43.911740Z",
"iopub.status.busy": "2023-04-18T17:59:43.911512Z",
"iopub.status.idle": "2023-04-18T17:59:49.103941Z",
"shell.execute_reply": "2023-04-18T17:59:49.103231Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" float,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7f79cc556070>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# We specify `element_accumulator` here so as to match the kernel run by NumPy below. However,\n",
"# specifying `element_accumulator` is not required if it is the same as `element`\n",
"plan = cutlass.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor, element_accumulator=np.float32)\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "4a5856de",
"metadata": {},
"source": [
"There are many other ways to construct a plan from `cutlass.op.Gemm` (e.g., by specifiying they types and layouts of each operand, by providing representative tensors as inputs). For more details on these, see the documentation in the `cutlass.op.Gemm` constructor."
]
},
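{
"cell_type": "markdown",
"id": "a1f3c2d4",
"metadata": {},
"source": [
"As a minimal sketch of one such alternative, the plan below specifies each operand explicitly. The per-operand keyword names (`element_A`, `layout_A`, and so on) are taken from the `cutlass.op.Gemm` docstring; consult it for the authoritative parameter list."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f3c2d5",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch: construct an equivalent plan by specifying each operand's\n",
"# type and layout explicitly rather than a single `element`/`layout` pair.\n",
"plan_explicit = cutlass.Gemm(\n",
"    element_A=type_A, element_B=type_B,\n",
"    element_C=type_C, element_D=type_D,\n",
"    layout_A=cutlass.LayoutType.RowMajor,\n",
"    layout_B=cutlass.LayoutType.RowMajor,\n",
"    layout_C=cutlass.LayoutType.RowMajor,\n",
"    element_accumulator=np.float32)"
]
},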
{
"cell_type": "markdown",
"id": "945478ef",
"metadata": {},
"source": [
"We then compare the output to running the GEMM using NumPy."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6b669de6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:49.107492Z",
"iopub.status.busy": "2023-04-18T17:59:49.107284Z",
"iopub.status.idle": "2023-04-18T17:59:49.138511Z",
"shell.execute_reply": "2023-04-18T17:59:49.137837Z"
}
},
"outputs": [],
"source": [
"tensor_D_numpy = (alpha * (tensor_A @ tensor_B)) + (beta * tensor_C)\n",
"np.testing.assert_array_equal(tensor_D, tensor_D_numpy)"
]
},
{
"cell_type": "markdown",
"id": "ee5cbbbe",
"metadata": {},
"source": [
"Note that one could use the same kernel just declared for tensors provided by other frameworks beyond NumPy, such as PyTorch or CuPy."
]
},
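{
"cell_type": "markdown",
"id": "a1f3c2d6",
"metadata": {},
"source": [
"As a sketch of this (assuming PyTorch is installed and a CUDA device is available), the same plan can consume `torch` tensors directly; we skip the demonstration gracefully if PyTorch is unavailable."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f3c2d7",
"metadata": {},
"outputs": [],
"source": [
"# A sketch only: reuse the plan declared above with PyTorch tensors.\n",
"# Assumes PyTorch with a CUDA device; skips gracefully otherwise.\n",
"try:\n",
"    import torch\n",
"    A = torch.ceil(torch.empty((m, k), dtype=torch.float16, device='cuda').uniform_(scope_min, scope_max))\n",
"    B = torch.ceil(torch.empty((k, n), dtype=torch.float16, device='cuda').uniform_(scope_min, scope_max))\n",
"    C = torch.ceil(torch.empty((m, n), dtype=torch.float16, device='cuda').uniform_(scope_min, scope_max))\n",
"    D = torch.zeros_like(C)\n",
"    plan.run(A, B, C, D)\n",
"except ImportError:\n",
"    print('PyTorch not available; skipping this demonstration')"
]
},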
{
"cell_type": "markdown",
"id": "b6c86493",
"metadata": {},
"source": [
"## Changing operation modes\n",
"By default, the CUTLASS Python interface will try to use Tensor Core operations whenever possible. If the configuration provided to `cutlass.op.Gemm` is not supported on Tensor Cores, the interface will fall back to using a SIMT kernel.\n",
"\n",
"The operation mode currently in use can be returned via the `plan.opclass` property. In this case Tensor Core operations."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "529fda93",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:49.141458Z",
"iopub.status.busy": "2023-04-18T17:59:49.141305Z",
"iopub.status.idle": "2023-04-18T17:59:49.145005Z",
"shell.execute_reply": "2023-04-18T17:59:49.144332Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OpcodeClass.TensorOp\n"
]
}
],
"source": [
"print(plan.opclass)"
]
},
{
"cell_type": "markdown",
"id": "6d27c575",
"metadata": {},
"source": [
"Suppose that we don't want to use Tensor Cores for this GEMM. One can change to using CUTLASS's SIMT GEMMs by setting the plan's `opclass` field.\n",
"\n",
"As is shown in the printed output, the emitted kernel uses template parameters that fit CUTLASS's SIMT GEMMs.\n",
"\n",
"Also notice that, this time around, we provided tensor parameters to `plan.run()`. One is free to provide different parameters to `plan.run()` than were passed in at the initial call to `cutlass.op.Gemm`, provided that the passed-in tensors have the same data type and layout as those passed in on intialization."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6a44d35b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:49.148548Z",
"iopub.status.busy": "2023-04-18T17:59:49.148042Z",
"iopub.status.idle": "2023-04-18T17:59:54.365792Z",
"shell.execute_reply": "2023-04-18T17:59:54.364734Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1\n",
"using cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" float,\n",
" cutlass::arch::OpClassSimt,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<128, 128, 8>,\n",
" cutlass::gemm::GemmShape<32, 64, 8>,\n",
" cutlass::gemm::GemmShape<1, 1, 1>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 1, float, float>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 2,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1_type : \n",
" public cutlass_sm80_simt_f16_sgemm_f16_1x1x1_128x128_8x2_tt_align1_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7f7b3075abe0>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensor_D_simt = np.zeros(tensor_C.shape).astype(type_D)\n",
"plan.opclass = cutlass.OpcodeClass.Simt\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D_simt, alpha, beta, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "639dcb59",
"metadata": {},
"source": [
"If we compare the output of the Tensor Core and SIMT GEMMs we just ran we see that they are equal."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9b480853",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:54.369977Z",
"iopub.status.busy": "2023-04-18T17:59:54.369302Z",
"iopub.status.idle": "2023-04-18T17:59:54.375239Z",
"shell.execute_reply": "2023-04-18T17:59:54.374405Z"
}
},
"outputs": [],
"source": [
"np.testing.assert_array_equal(tensor_D, tensor_D_simt)"
]
},
{
"cell_type": "markdown",
"id": "0cce1eae",
"metadata": {},
"source": [
"## Running cached kernels\n",
"You may have noticed that the `plan.run()` calls for the previous two kernels took some time to execute. This is because the kernel being emitted had not yet been compiled.\n",
"\n",
"CUTLASS caches compiled binaries so that recompilation isn't necessary every time a kernel is run. For example, if we change modes back to using Tensor Cores and call `plan.run()` again (with a different set of tensor parameters), you'll find the call to return much faster."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f8051e5e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:54.378373Z",
"iopub.status.busy": "2023-04-18T17:59:54.378060Z",
"iopub.status.idle": "2023-04-18T17:59:55.220086Z",
"shell.execute_reply": "2023-04-18T17:59:55.219198Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" float,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7f7b30fb9880>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = 2400\n",
"n = 3232\n",
"k = 4096\n",
"\n",
"tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n",
"tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n",
"tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n",
"tensor_D = np.zeros(tensor_C.shape).astype(type_D)\n",
"\n",
"alpha = np.float16(1.)\n",
"beta = np.float16(2.)\n",
"\n",
"plan.opclass = cutlass.OpcodeClass.TensorOp\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)"
]
},
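{
"cell_type": "markdown",
"id": "a1f3c2d8",
"metadata": {},
"source": [
"To make the effect of caching visible, here is a quick, informal timing of a repeated call with the same configuration. The numbers are illustrative only, not a rigorous benchmark."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f3c2d9",
"metadata": {},
"outputs": [],
"source": [
"# Informal timing sketch: a second run with the same configuration should\n",
"# skip kernel emission and compilation entirely.\n",
"import time\n",
"start = time.time()\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta)\n",
"print('Cached run took {:.3f} seconds'.format(time.time() - start))"
]
},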
{
"cell_type": "markdown",
"id": "52a4e318",
"metadata": {},
"source": [
"## Running non-default GEMMs\n",
"The previous examples showed how it is simple to get started running a default GEMM kernel in CUTLASS. But, what do you do if you want a bit more control over the parameters to the GEMM?\n",
"\n",
"Under the hood, CUTLASS enumerates the different GEMM configuration parameters possible for this kernel from the CUTLASS profiler. The code below shows how one can access the tile descriptions for the kernels (e.g., cluster, threadblock, and warp shape)."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1c593be1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:55.223812Z",
"iopub.status.busy": "2023-04-18T17:59:55.223651Z",
"iopub.status.idle": "2023-04-18T17:59:55.228769Z",
"shell.execute_reply": "2023-04-18T17:59:55.228101Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"132 tile descriptions returned\n",
"First 10 tile descriptions are:\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [256, 128, 64]\n",
" WarpCount: [4, 2, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [128, 256, 64]\n",
" WarpCount: [2, 4, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [256, 128, 64]\n",
" WarpCount: [4, 2, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [128, 256, 64]\n",
" WarpCount: [2, 4, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [256, 128, 32]\n",
" WarpCount: [4, 2, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [128, 256, 32]\n",
" WarpCount: [2, 4, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [256, 64, 64]\n",
" WarpCount: [4, 1, 1]\n",
" Stages: 4\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [64, 256, 64]\n",
" WarpCount: [1, 4, 1]\n",
" Stages: 4\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [128, 128, 64]\n",
" WarpCount: [2, 2, 1]\n",
" Stages: 4\n",
" Kernel schedule: ScheduleAuto\n",
"}\n",
"\n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [256, 64, 64]\n",
" WarpCount: [4, 1, 1]\n",
" Stages: 3\n",
" Kernel schedule: ScheduleAuto\n",
"}\n"
]
}
],
"source": [
"tiles = plan.tile_descriptions()\n",
"print('{} tile descriptions returned'.format(len(tiles)))\n",
"num_print = 10\n",
"print('First {} tile descriptions are:'.format(num_print))\n",
"for td in tiles[:num_print]:\n",
" print(td)"
]
},
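{
"cell_type": "markdown",
"id": "a1f3c2da",
"metadata": {},
"source": [
"Rather than picking at random (as we do next), one can also filter the returned descriptions by attribute. Here is a sketch that keeps only the 3-stage configurations (`stages` is the same attribute modified in the error-handling example at the end of this notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f3c2db",
"metadata": {},
"outputs": [],
"source": [
"# A sketch: filter tile descriptions by an attribute, here the stage count.\n",
"three_stage = [td for td in tiles if td.stages == 3]\n",
"print('{} of {} tile descriptions use 3 stages'.format(len(three_stage), len(tiles)))"
]
},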
{
"cell_type": "markdown",
"id": "dc3ad875",
"metadata": {},
"source": [
"Next, we'll pick one of these configurations at random and compile and run it."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a8dc5287",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T17:59:55.231498Z",
"iopub.status.busy": "2023-04-18T17:59:55.230924Z",
"iopub.status.idle": "2023-04-18T18:00:00.340161Z",
"shell.execute_reply": "2023-04-18T18:00:00.339603Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tile description 112 is: \n",
"{\n",
" ClusterShape: [1, 1, 1]\n",
" ThreadblockShape: [128, 128, 32]\n",
" WarpCount: [2, 2, 1]\n",
" Stages: 4\n",
" Kernel schedule: ScheduleAuto\n",
"}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8\n",
"using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" float,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<128, 128, 32>,\n",
" cutlass::gemm::GemmShape<64, 64, 32>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 4,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_type : \n",
" public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7f79cc58de20>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idx = random.randint(0, len(tiles)-1)\n",
"td = tiles[idx]\n",
"print('Tile description {} is: {}'.format(idx, td))\n",
"plan.compile(td)\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "c5a8b534",
"metadata": {},
"source": [
"One can also change the swizzling function used by the kernel. For example, one can modify the kernel to use the stream K feature of CUTLASS via:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e5e88d17",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:00.343772Z",
"iopub.status.busy": "2023-04-18T18:00:00.343582Z",
"iopub.status.idle": "2023-04-18T18:00:06.192256Z",
"shell.execute_reply": "2023-04-18T18:00:06.191286Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8\n",
"using cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" float,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<128, 128, 32>,\n",
" cutlass::gemm::GemmShape<64, 64, 32>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,\n",
" cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,\n",
" 4,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_type : \n",
" public cutlass_sm80_tensorop_f16_s16x8x16gemm_f16_1x1x1_128x128_32x4_tt_align8_base { };\n",
"\n"
]
}
],
"source": [
"# Stream K is only supported pre-SM90 (at least when this example was written)\n",
"if plan.cc != 90:\n",
" plan.swizzling_functor = cutlass.swizzle.ThreadblockSwizzleStreamK\n",
" plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "5a8ba2ba",
"metadata": {},
"source": [
"## Handling errors\n",
"The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.\n",
"\n",
"Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fe7d0e42",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:06.196345Z",
"iopub.status.busy": "2023-04-18T18:00:06.195784Z",
"iopub.status.idle": "2023-04-18T18:00:06.199248Z",
"shell.execute_reply": "2023-04-18T18:00:06.198438Z"
}
},
"outputs": [],
"source": [
"# td = tiles[0]\n",
"# td.stages = 8\n",
"# plan.compile(td)"
]
}
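,
{
"cell_type": "markdown",
"id": "a1f3c2dc",
"metadata": {},
"source": [
"To handle such a failure programmatically, here is a sketch that catches the reported error. We deliberately catch the broad `Exception`, since we do not assume the precise exception class raised by the interface."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f3c2dd",
"metadata": {},
"outputs": [],
"source": [
"# A sketch of catching the reported error. The broad `except Exception` is\n",
"# deliberate: the precise exception class is not assumed here.\n",
"td = tiles[0]\n",
"td.stages = 8\n",
"try:\n",
"    plan.compile(td)\n",
"except Exception as e:\n",
"    print('Compilation rejected as expected: {}'.format(e))"
]
}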
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "0466d96796c9cd8f7a1cad264ff326ececc950ba2420e0256d5105fc1a3c6e70"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}