{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "5d24a692",
"metadata": {},
"source": [
"# Example of using elementwise activation functions in the CUTLASS Python interface\n",
"This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.\n",
"\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)"
]
},
{
"cell_type": "markdown",
"id": "3ca993fe",
"metadata": {},
"source": [
"We first import the packages needed for this example and construct the input and output tensors that it uses."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "63a70a3c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:09.148380Z",
"iopub.status.busy": "2023-04-18T18:00:09.148011Z",
"iopub.status.idle": "2023-04-18T18:00:13.281937Z",
"shell.execute_reply": "2023-04-18T18:00:13.281256Z"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"import cutlass\n",
"\n",
"# This controls whether the C++ GEMM declaration is printed at each step. Set to `False` to\n",
"# omit this information.\n",
"print_module = True\n",
"\n",
"m = 256\n",
"n = m\n",
"k = m\n",
"\n",
"type_A = np.float16\n",
"type_B = np.float16\n",
"type_C = np.float16\n",
"type_D = np.float16\n",
"\n",
"np.random.seed(1234)\n",
"scope_min = -4\n",
"scope_max = 4\n",
"tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n",
"tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n",
"tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n",
"\n",
"alpha = np.float16(1.)\n",
"beta = np.float16(0.)\n",
"\n",
"tensor_D = np.zeros(tensor_C.shape).astype(type_D)"
]
},
{
"cell_type": "markdown",
"id": "1eb0d95b",
"metadata": {},
"source": [
"## Run a GEMM with an identity activation function\n",
"To begin, we run a default GEMM, which uses the identity activation function. This performs the familiar operation `D = alpha * (A @ B) + beta * C`. Identity is the default activation, so it does not need to be specified."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8d257833",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:13.284650Z",
"iopub.status.busy": "2023-04-18T18:00:13.284425Z",
"iopub.status.idle": "2023-04-18T18:00:18.333867Z",
"shell.execute_reply": "2023-04-18T18:00:18.333187Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed907287c0>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "54961694",
"metadata": {},
"source": [
"## Run a GEMM with a ReLU element-wise activation function\n",
"CUTLASS makes it easy to support other element-wise activation functions. The activation is applied element-wise to the result of the linear combination that a GEMM normally computes. If we call such an activation function `act`, the resulting formulation is:\n",
"```\n",
"D = alpha * (A @ B) + beta * C\n",
"D = act(D)\n",
"```\n",
"\n",
"Here, we will add a ReLU activation function. Given an input `x`, ReLU returns `max(x, 0)`.\n",
"\n",
"To do so, we only need to set the plan's `activation` field."
]
},
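{
"cell_type": "markdown",
"id": "epilogue-numpy-reference-sketch",
"metadata": {},
"source": [
"The two-step formulation above can be mirrored on the host with NumPy, which is handy for checking results. The following is a minimal sketch under stated assumptions, not part of the CUTLASS Python interface: `reference_epilogue`, `relu_np`, and the `*_small` matrices are hypothetical names introduced only for illustration. It computes in float32 and rounds once to float16, so for general inputs it need not match a kernel that accumulates in `half_t` bit-for-bit.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def reference_epilogue(A, B, C, alpha, beta, act):\n",
"    # Hypothetical host-side reference (not a CUTLASS API):\n",
"    # D = act(alpha * (A @ B) + beta * C), evaluated in float32, rounded once to float16.\n",
"    acc = A.astype(np.float32) @ B.astype(np.float32)\n",
"    D = alpha * acc + beta * C.astype(np.float32)\n",
"    return act(D).astype(np.float16)\n",
"\n",
"def relu_np(x):\n",
"    # ReLU returns max(x, 0) element-wise.\n",
"    return np.maximum(x, 0.0)\n",
"\n",
"A_small = np.array([[1, -2], [3, 0]], dtype=np.float16)\n",
"B_small = np.array([[1, 1], [2, -1]], dtype=np.float16)\n",
"C_small = np.ones((2, 2), dtype=np.float16)\n",
"\n",
"# With alpha = 1 and beta = 0, A_small @ B_small is [[-3, 3], [3, 3]], so ReLU zeroes one entry.\n",
"D_identity = reference_epilogue(A_small, B_small, C_small, 1.0, 0.0, lambda x: x)\n",
"D_relu = reference_epilogue(A_small, B_small, C_small, 1.0, 0.0, relu_np)\n",
"print(D_identity)\n",
"print(D_relu)\n",
"```"
]
},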
{
"cell_type": "code",
"execution_count": 3,
"id": "5fe49443",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:18.337036Z",
"iopub.status.busy": "2023-04-18T18:00:18.336833Z",
"iopub.status.idle": "2023-04-18T18:00:23.482072Z",
"shell.execute_reply": "2023-04-18T18:00:23.481125Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed906f2460>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensor_D_relu = np.zeros(tensor_C.shape).astype(type_D)\n",
"plan.activation = cutlass.epilogue.relu\n",
"plan.run(tensor_A, tensor_B, tensor_C, tensor_D_relu, print_module=print_module)"
]
},
{
"cell_type": "markdown",
"id": "455d0a37",
"metadata": {},
"source": [
"We can now verify the result of the GEMM that used a ReLU activation function by applying ReLU in NumPy to the output of the earlier identity GEMM and comparing the two:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e32e7798",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:23.486042Z",
"iopub.status.busy": "2023-04-18T18:00:23.485342Z",
"iopub.status.idle": "2023-04-18T18:00:23.497444Z",
"shell.execute_reply": "2023-04-18T18:00:23.496668Z"
}
},
"outputs": [],
"source": [
"relu_ref = (tensor_D >= 0).astype(type_D) * tensor_D\n",
"np.testing.assert_array_equal(relu_ref, tensor_D_relu)"
]
},
{
"cell_type": "markdown",
"id": "cf959171",
"metadata": {},
"source": [
"## Other element-wise activation functions\n",
"CUTLASS supports a variety of widely used element-wise activation functions. We can obtain a list of these functions via the plan's `activations()` method."
]
},
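{
"cell_type": "markdown",
"id": "activation-numpy-references",
"metadata": {},
"source": [
"For checking outputs on the host, each of the activations returned by the plan's `activations()` method can be paired with a NumPy reference. The sketch below uses textbook definitions and is an assumption rather than a description of CUTLASS's device code: the GELU shown is the exact erf form, the `leaky_relu` slope of 0.01 is a hypothetical value chosen for illustration, and `numpy_activation_refs` is a name introduced only for this example. Because device implementations may use different approximations, comparisons against these references are best made with a tolerance.\n",
"\n",
"```python\n",
"import math\n",
"import numpy as np\n",
"\n",
"def sigmoid_np(x):\n",
"    return 1.0 / (1.0 + np.exp(-x))\n",
"\n",
"# NumPy has no built-in erf, so vectorize the scalar version from the math module.\n",
"erf_np = np.vectorize(math.erf, otypes=[np.float64])\n",
"\n",
"LEAKY_SLOPE = 0.01  # hypothetical slope, for illustration only\n",
"\n",
"# Textbook reference definitions, keyed by activation name (an assumption, not CUTLASS code).\n",
"numpy_activation_refs = {\n",
"    'gelu': lambda x: 0.5 * x * (1.0 + erf_np(x / math.sqrt(2.0))),\n",
"    'hardswish': lambda x: x * np.clip(x + 3.0, 0.0, 6.0) / 6.0,\n",
"    'identity': lambda x: x,\n",
"    'leaky_relu': lambda x: np.where(x > 0, x, LEAKY_SLOPE * x),\n",
"    'relu': lambda x: np.maximum(x, 0.0),\n",
"    'sigmoid': sigmoid_np,\n",
"    'silu': lambda x: x * sigmoid_np(x),\n",
"    'tanh': np.tanh,\n",
"}\n",
"\n",
"x_demo = np.linspace(-4.0, 4.0, 9, dtype=np.float32)\n",
"for name, fn in numpy_activation_refs.items():\n",
"    print(name, fn(x_demo).astype(np.float16))\n",
"```"
]
},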
{
"cell_type": "code",
"execution_count": 5,
"id": "9e17d730",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:23.500102Z",
"iopub.status.busy": "2023-04-18T18:00:23.499944Z",
"iopub.status.idle": "2023-04-18T18:00:23.504562Z",
"shell.execute_reply": "2023-04-18T18:00:23.503793Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'cutlass.backend.epilogue.gelu'>\n",
"<class 'cutlass.backend.epilogue.hardswish'>\n",
"<class 'cutlass.backend.epilogue.identity'>\n",
"<class 'cutlass.backend.epilogue.leaky_relu'>\n",
"<class 'cutlass.backend.epilogue.relu'>\n",
"<class 'cutlass.backend.epilogue.sigmoid'>\n",
"<class 'cutlass.backend.epilogue.silu'>\n",
"<class 'cutlass.backend.epilogue.tanh'>\n"
]
}
],
"source": [
"activations = plan.activations()\n",
"for activation in activations:\n",
" print(activation)"
]
},
{
"cell_type": "markdown",
"id": "0e4599fa",
"metadata": {},
"source": [
"We can then compile and run a GEMM with each of them:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9c3598c9",
"metadata": {
"execution": {
"iopub.execute_input": "2023-04-18T18:00:23.507538Z",
"iopub.status.busy": "2023-04-18T18:00:23.507257Z",
"iopub.status.idle": "2023-04-18T18:00:59.414765Z",
"shell.execute_reply": "2023-04-18T18:00:59.414116Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.gelu'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.hardswish'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::HardSwish, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.identity'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n",
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.leaky_relu'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::LeakyReLU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.relu'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n",
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.sigmoid'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Sigmoid, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.silu'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::SiLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=============================================================================================\n",
"Compiling and running activation <class 'cutlass.backend.epilogue.tanh'>\n",
"=============================================================================================\n",
"\n",
"// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8\n",
"using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =\n",
" typename cutlass::gemm::kernel::DefaultGemmUniversal<\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,\n",
" cutlass::half_t, cutlass::layout::RowMajor,\n",
" cutlass::half_t,\n",
" cutlass::arch::OpClassTensorOp,\n",
" cutlass::arch::Sm80,\n",
" cutlass::gemm::GemmShape<256, 128, 64>,\n",
" cutlass::gemm::GemmShape<64, 64, 64>,\n",
" cutlass::gemm::GemmShape<16, 8, 16>,\n",
" cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Tanh, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,\n",
" cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,\n",
" 3,\n",
" cutlass::arch::OpMultiplyAdd\n",
">::GemmKernel;\n",
"\n",
"// Define named type\n",
"struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type : \n",
" public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };\n",
"\n"
]
}
],
"source": [
"for activation in activations:\n",
" print('=============================================================================================')\n",
" print(f'Compiling and running activation {activation}')\n",
" print('=============================================================================================')\n",
" plan.activation = activation\n",
" plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)"
]
},
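{
"cell_type": "markdown",
"id": "activation-tolerance-check-sketch",
"metadata": {},
"source": [
"The loop above overwrites `tensor_D` with each activation's result without checking it. Unlike ReLU, activations such as sigmoid, tanh, and GELU evaluate transcendental functions, and their float16 results can differ in the last bits between implementations, so a host-side check is better written with `np.testing.assert_allclose` than with `assert_array_equal`. The sketch below illustrates the idea on sigmoid alone, without a GPU: a float16-only evaluation stands in for a device result, and the tolerances `rtol=1e-2` and `atol=1e-3` are illustrative assumptions, not values prescribed by CUTLASS.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid_np(x):\n",
"    return 1.0 / (1.0 + np.exp(-x))\n",
"\n",
"x_demo = np.linspace(-4.0, 4.0, 17).astype(np.float16)\n",
"\n",
"# Reference: evaluate in float32, then round once to float16.\n",
"ref = sigmoid_np(x_demo.astype(np.float32)).astype(np.float16)\n",
"\n",
"# Stand-in for a device result: the same formula evaluated entirely in float16,\n",
"# which rounds after every operation.\n",
"candidate = sigmoid_np(x_demo)\n",
"\n",
"# Bitwise equality is not guaranteed, so compare with a tolerance instead.\n",
"np.testing.assert_allclose(candidate.astype(np.float32), ref.astype(np.float32), rtol=1e-2, atol=1e-3)\n",
"print('max abs difference:', np.max(np.abs(candidate.astype(np.float32) - ref.astype(np.float32))))\n",
"```"
]
},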
{
"cell_type": "code",
"execution_count": null,
"id": "751f8d92",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}