cutlass/python/docs/externals/02_pytorch_extension_groupe...

537 lines
40 KiB
HTML
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!doctype html>
<html class="no-js" lang="en">
<head><meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,initial-scale=1"/>
<meta name="color-scheme" content="light dark"><meta name="generator" content="Docutils 0.19: https://docutils.sourceforge.io/" />
<link rel="index" title="Index" href="../genindex.html" /><link rel="search" title="Search" href="../search.html" /><link rel="prev" title="Example of using elementwise activation functions in the CUTLASS Python interface" href="01_epilogue.html" />
<link rel="canonical" href="docs/externals/02_pytorch_extension_grouped_gemm.html" />
<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
<title>Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension - CUTLASS Python</title>
<link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css" />
<link rel="stylesheet" type="text/css" href="../_static/tabs.css" />
<link rel="stylesheet" type="text/css" href="../_static/nbsphinx-code-cells.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
<style>
body {
--color-code-background: #eeffcc;
--color-code-foreground: black;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
@media not print {
body[data-theme="dark"] {
--color-code-background: #272822;
--color-code-foreground: #f8f8f2;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
@media (prefers-color-scheme: dark) {
body:not([data-theme="light"]) {
--color-code-background: #272822;
--color-code-foreground: #f8f8f2;
--color-brand-primary: #76B900;
--color-brand-content: #76B900;
}
}
}
</style></head>
<body>
<script>
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
</script>
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
<symbol id="svg-toc" viewBox="0 0 24 24">
<title>Contents</title>
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
</svg>
</symbol>
<symbol id="svg-menu" viewBox="0 0 24 24">
<title>Menu</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
<line x1="3" y1="12" x2="21" y2="12"></line>
<line x1="3" y1="6" x2="21" y2="6"></line>
<line x1="3" y1="18" x2="21" y2="18"></line>
</svg>
</symbol>
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
<title>Expand</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
<polyline points="9 18 15 12 9 6"></polyline>
</svg>
</symbol>
<symbol id="svg-sun" viewBox="0 0 24 24">
<title>Light mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
<circle cx="12" cy="12" r="5"></circle>
<line x1="12" y1="1" x2="12" y2="3"></line>
<line x1="12" y1="21" x2="12" y2="23"></line>
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
<line x1="1" y1="12" x2="3" y2="12"></line>
<line x1="21" y1="12" x2="23" y2="12"></line>
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
</svg>
</symbol>
<symbol id="svg-moon" viewBox="0 0 24 24">
<title>Dark mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
</svg>
</symbol>
<symbol id="svg-sun-half" viewBox="0 0 24 24">
<title>Auto light/dark mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<circle cx="12" cy="12" r="9" />
<path d="M13 12h5" />
<path d="M13 15h4" />
<path d="M13 18h1" />
<path d="M13 9h4" />
<path d="M13 6h1" />
</svg>
</symbol>
</svg>
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
<label class="overlay sidebar-overlay" for="__navigation">
<div class="visually-hidden">Hide navigation sidebar</div>
</label>
<label class="overlay toc-overlay" for="__toc">
<div class="visually-hidden">Hide table of contents sidebar</div>
</label>
<div class="page">
<header class="mobile-header">
<div class="header-left">
<label class="nav-overlay-icon" for="__navigation">
<div class="visually-hidden">Toggle site navigation sidebar</div>
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
</label>
</div>
<div class="header-center">
<a href="../index.html"><div class="brand">CUTLASS Python</div></a>
</div>
<div class="header-right">
<div class="theme-toggle-container theme-toggle-header">
<button class="theme-toggle">
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
</button>
</div>
<label class="toc-overlay-icon toc-header-icon" for="__toc">
<div class="visually-hidden">Toggle table of contents sidebar</div>
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
</label>
</div>
</header>
<aside class="sidebar-drawer">
<div class="sidebar-container">
<div class="sidebar-sticky"><a class="sidebar-brand" href="../index.html">
<div class="sidebar-logo-container">
<img class="sidebar-logo only-light" src="../_static/cutlass-logo-small.png" alt="Light Logo"/>
<img class="sidebar-logo only-dark" src="../_static/cutlass-logo-small.png" alt="Dark Logo"/>
</div>
<span class="sidebar-brand-text">CUTLASS Python</span>
</a><form class="sidebar-search-container" method="get" action="../search.html" role="search">
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
<input type="hidden" name="check_keywords" value="yes">
<input type="hidden" name="area" value="default">
</form>
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
<ul>
<li class="toctree-l1"><a class="reference internal" href="../index.html">Home</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../install.html">Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="00_basic_gemm.html">Getting Started</a></li>
<li class="toctree-l1"><a class="reference internal" href="../contribute.html">Contributing</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
<ul>
<li class="toctree-l1 has-children"><a class="reference internal" href="../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
<li class="toctree-l2 has-children"><a class="reference internal" href="../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.emit.html">Emitters</a></li>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.op.html">Operations</a></li>
<li class="toctree-l3"><a class="reference internal" href="../cutlass.utils.html">Utilities</a></li>
</ul>
</li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
<ul class="current">
<li class="toctree-l1 current has-children"><a class="reference internal" href="../examples.html">Examples</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="00_basic_gemm.html">Basic GEMM</a></li>
<li class="toctree-l2"><a class="reference internal" href="01_epilogue.html">Epilogue</a></li>
<li class="toctree-l2 current current-page"><a class="current reference internal" href="#">PyTorch Extension</a></li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
<ul>
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">Github</a></li>
</ul>
</div>
</div>
</div>
</div>
</aside>
<div class="main">
<div class="content">
<div class="article-container">
<a href="#" class="back-to-top muted-link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
</svg>
<span>Back to top</span>
</a>
<div class="content-icon-container">
<div class="theme-toggle-container theme-toggle-content">
<button class="theme-toggle">
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
</button>
</div>
<label class="toc-overlay-icon toc-content-icon" for="__toc">
<div class="visually-hidden">Toggle table of contents sidebar</div>
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
</label>
</div>
<article role="main">
<section id="Exporting-a-CUTLASS-grouped-GEMM-kernel-to-a-PyTorch-CUDA-extension">
<h1>Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension<a class="headerlink" href="#Exporting-a-CUTLASS-grouped-GEMM-kernel-to-a-PyTorch-CUDA-extension" title="Permalink to this heading">#</a></h1>
<p>This notebook walks through a basic example of using the CUTLASS Python interface to declare a grouped GEMM kernel and export it as a PyTorch CUDA extension.</p>
<p><a class="reference external" href="https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/02_pytorch_extension_grouped_gemm.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a></p>
<section id="Background-on-grouped-GEMM">
<h2>Background on grouped GEMM<a class="headerlink" href="#Background-on-grouped-GEMM" title="Permalink to this heading">#</a></h2>
<p>Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides) in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM, without the requirement that the sizes and strides of each GEMM be the same.</p>
<p>For example, if one has <code class="docutils literal notranslate"><span class="pre">p</span></code> GEMMs with sizes:</p>
<div class="highlight-text notranslate"><div class="highlight"><pre><span></span>M_1 x N_1 x K_1
M_2 x N_2 x K_2
...
M_p x N_p x K_p
</pre></div>
</div>
<p>CUTLASS’s grouped GEMM will execute these in a single CUDA kernel.</p>
<p>Grouped GEMM is particularly beneficial for saturating the GPU with many small problems that would insufficiently utilize the device in isolation.</p>
</section>
<section id="Declaring-a-grouped-GEMM-via-the-CUTLASS-Python-interface">
<h2>Declaring a grouped GEMM via the CUTLASS Python interface<a class="headerlink" href="#Declaring-a-grouped-GEMM-via-the-CUTLASS-Python-interface" title="Permalink to this heading">#</a></h2>
<p>A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one simply calls <code class="docutils literal notranslate"><span class="pre">cutlass.op.GroupedGemm</span></code>.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[1]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">cutlass</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span>
<span class="n">plan</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">GroupedGemm</span><span class="p">(</span><span class="n">element</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area stderr docutils container">
<div class="highlight"><pre>
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
</pre></div></div>
</div>
<p>We can then compile and run this operation on a group of GEMMs. We’ll first set up some utility functions to initialize GEMMs.</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">random</span>
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">2023</span><span class="p">)</span>
<span class="c1"># Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K</span>
<span class="k">def</span> <span class="nf">initialize</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">):</span>
<span class="n">sizes</span> <span class="o">=</span> <span class="p">[(</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="p">(</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)]</span>
<span class="k">return</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">sizes</span><span class="p">]</span>
<span class="c1"># Utility function to generate `problems` GEMMs of random sizes</span>
<span class="k">def</span> <span class="nf">generate_problems</span><span class="p">(</span><span class="n">problems</span><span class="p">):</span>
<span class="n">valid_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">1024</span><span class="p">]</span>
<span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">problems</span><span class="p">):</span>
<span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span> <span class="o">=</span> <span class="p">[</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">valid_sizes</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">)]</span>
<span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">initialize</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">)</span>
<span class="n">As</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">A</span><span class="p">)</span>
<span class="n">Bs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
<span class="n">Cs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">C</span><span class="p">)</span>
<span class="n">Ds</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">D</span><span class="p">)</span>
<span class="k">return</span> <span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span>
</pre></div>
</div>
</div>
<p>We’ll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch, and check that the two sets of results match.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span><span class="p">,</span> <span class="o">=</span> <span class="n">generate_problems</span><span class="p">(</span><span class="mi">50</span><span class="p">)</span>
<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">Ds</span><span class="p">,</span> <span class="n">Ds_torch</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmGrouped&lt;
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape&lt;256, 128, 64&gt;,
cutlass::gemm::GemmShape&lt;64, 64, 64&gt;,
cutlass::gemm::GemmShape&lt;16, 8, 16&gt;,
cutlass::epilogue::thread::LinearCombination&lt;cutlass::half_t, 8, cutlass::half_t, cutlass::half_t&gt;,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle&lt;1&gt;,
3,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
cutlass::arch::OpMultiplyAdd
&gt;::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_base { };
</pre></div></div>
</div>
</section>
<section id="Exporting-the-CUTLASS-kernel-to-a-PyTorch-CUDA-extension">
<h2>Exporting the CUTLASS kernel to a PyTorch CUDA extension<a class="headerlink" href="#Exporting-the-CUTLASS-kernel-to-a-PyTorch-CUDA-extension" title="Permalink to this heading">#</a></h2>
<p>The procedure above allows one to quickly experiment with using a CUTLASS kernel. However, one might prefer to use the CUTLASS kernel via a <a class="reference external" href="https://pytorch.org/tutorials/advanced/cpp_extension.html">PyTorch CUDA extension</a>. This avoids adding any runtime overheads associated with the Python portions of the CUTLASS Python interface.</p>
<p>The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later “ahead-of-time” compilation, or be just-in-time compiled and returned to the user.</p>
<p>To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[4]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">op</span> <span class="o">=</span> <span class="n">plan</span><span class="o">.</span><span class="n">construct</span><span class="p">()</span>
<span class="n">grouped_gemm</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">emit</span><span class="o">.</span><span class="n">pytorch</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;grouped_gemm&#39;</span><span class="p">,</span> <span class="n">cc</span><span class="o">=</span><span class="n">plan</span><span class="o">.</span><span class="n">cc</span><span class="p">,</span> <span class="n">sourcedir</span><span class="o">=</span><span class="s1">&#39;out&#39;</span><span class="p">,</span> <span class="n">jit</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
</div>
<p>The <code class="docutils literal notranslate"><span class="pre">cutlass.emit.pytorch</span></code> function emits: * <code class="docutils literal notranslate"><span class="pre">out/grouped_gemm_kernel.cu</span></code>: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors * <code class="docutils literal notranslate"><span class="pre">out/grouped_gemm.cpp</span></code>: This file contains a C++ wrapper around the aforementioned CUTLASS kernel * <code class="docutils literal notranslate"><span class="pre">setup.py</span></code>: This file contains the <code class="docutils literal notranslate"><span class="pre">setuptools</span></code> script for building and installing the generated extension</p>
<p>The extension can be built from within the <code class="docutils literal notranslate"><span class="pre">out</span></code> directory by running:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nv">TORCH_CUDA_ARCH_LIST</span><span class="o">=</span><span class="s2">&quot;8.0&quot;</span><span class="w"> </span>python<span class="w"> </span>setup.py<span class="w"> </span>install
</pre></div>
</div>
<p>Where <code class="docutils literal notranslate"><span class="pre">TORCH_ARCH_LIST</span></code> is set to the compute capability of the device on which the kernel will be run.</p>
<p>See the PyTorch <a class="reference external" href="https://pytorch.org/tutorials/advanced/cpp_extension.html">“Custom C++ and CUDA Extensions”</a> tutorial for more details on this.</p>
<p>The PyTorch CUDA extension could be built for this module by running:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nb">cd</span><span class="w"> </span>out
<span class="nv">TORCH_CUDA_ARCH_LIST</span><span class="o">=</span><span class="s2">&quot;8.0&quot;</span><span class="w"> </span>python<span class="w"> </span>setup.py
</pre></div>
</div>
<p>(assuming that one is building for SM80)</p>
<p>One could then use the kernel in a later PyTorch module by running:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">grouped_gemm</span>
<span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
</pre></div>
</div>
<p>In this case, however, we set <code class="docutils literal notranslate"><span class="pre">jit=True</span></code>, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly. Under the hood, this leverages the <a class="reference external" href="https://pytorch.org/tutorials/advanced/cpp_extension.html">torch.utils.cpp_extension.load</a> method and returns back the loaded extension.</p>
<p>We can then use the extension and compare its results to those obtained by running the GEMMs individually in vanilla PyTorch:</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[5]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">Ds</span> <span class="o">=</span> <span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
<span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">Ds</span><span class="p">,</span> <span class="n">Ds_torch</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span><span class="p">)</span>
</pre></div>
</div>
</div>
<p>Finally, we can profile our grouped GEMM extension:</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[6]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">num_warmup</span> <span class="o">=</span> <span class="mi">20</span>
<span class="n">num_profile</span> <span class="o">=</span> <span class="mi">100</span>
<span class="c1"># Warmup iterations</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_warmup</span><span class="p">):</span>
<span class="n">Ds</span> <span class="o">=</span> <span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
<span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="c1"># Timing iterations</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="n">grouped</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">nongrouped</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_profile</span><span class="p">):</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">Ds</span> <span class="o">=</span> <span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">grouped</span> <span class="o">+=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">nongrouped</span> <span class="o">+=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Grouped: </span><span class="si">{:.3f}</span><span class="s1"> us&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">grouped</span> <span class="o">*</span> <span class="mf">1e6</span><span class="o">/</span><span class="n">num_profile</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Non-Grouped: </span><span class="si">{:.3f}</span><span class="s1"> us&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">nongrouped</span> <span class="o">*</span> <span class="mf">1e6</span><span class="o">/</span><span class="n">num_profile</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Speedup: </span><span class="si">{:.3f}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">nongrouped</span> <span class="o">/</span> <span class="n">grouped</span><span class="p">))</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
Grouped: 400.696 us
Non-Grouped: 646.670 us
Speedup: 1.614
</pre></div></div>
</div>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[ ]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>
</pre></div>
</div>
</div>
</section>
</section>
</article>
</div>
<footer>
<div class="related-pages">
<a class="prev-page" href="01_epilogue.html">
<svg class="furo-related-icon"><use href="#svg-arrow-right"></use></svg>
<div class="page-info">
<div class="context">
<span>Previous</span>
</div>
<div class="title">Example of using elementwise activation functions in the CUTLASS Python interface</div>
</div>
</a>
</div>
<div class="bottom-of-page">
<div class="left-details">
<div class="copyright">
Copyright &#169; 2023, NVIDIA
</div>
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
<a href="https://github.com/pradyunsg/furo">Furo</a>
</div>
<div class="right-details">
<div class="icons">
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
</svg>
</a>
</div>
</div>
</div>
</footer>
</div>
<aside class="toc-drawer">
<div class="toc-sticky toc-scroll">
<div class="toc-title-container">
<span class="toc-title">
On this page
</span>
</div>
<div class="toc-tree-container">
<div class="toc-tree">
<ul>
<li><a class="reference internal" href="#">Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension</a><ul>
<li><a class="reference internal" href="#Background-on-grouped-GEMM">Background on grouped GEMM</a></li>
<li><a class="reference internal" href="#Declaring-a-grouped-GEMM-via-the-CUTLASS-Python-interface">Declaring a grouped GEMM via the CUTLASS Python interface</a></li>
<li><a class="reference internal" href="#Exporting-the-CUTLASS-kernel-to-a-PyTorch-CUDA-extension">Exporting the CUTLASS kernel to a PyTorch CUDA extension</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
</aside>
</div>
</div><script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/sphinx_highlight.js"></script>
<script src="../_static/scripts/furo.js"></script>
<script src="../_static/clipboard.min.js"></script>
<script src="../_static/copybutton.js"></script>
<script src="../_static/tabs.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script>window.MathJax = {"tex": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true}, "options": {"ignoreHtmlClass": "tex2jax_ignore|mathjax_ignore|document", "processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
<script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
</body>
</html>