<!doctype html>
<html class="no-js" lang="en">
<head><meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,initial-scale=1"/>
<meta name="color-scheme" content="light dark"><meta name="generator" content="Docutils 0.19: https://docutils.sourceforge.io/" />
<link rel="index" title="Index" href="../genindex.html" /><link rel="search" title="Search" href="../search.html" /><link rel="prev" title="Example of using elementwise activation functions in the CUTLASS Python interface" href="01_epilogue.html" />
<link rel="canonical" href="docs/externals/02_pytorch_extension_grouped_gemm.html" />

<!-- Generated with Sphinx 6.1.3 and Furo 2023.03.27 -->
<title>Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension - CUTLASS Python</title>
<link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo.css?digest=fad236701ea90a88636c2a8c73b44ae642ed2a53" />
<link rel="stylesheet" type="text/css" href="../_static/copybutton.css" />
<link rel="stylesheet" type="text/css" href="../_static/tabs.css" />
<link rel="stylesheet" type="text/css" href="../_static/nbsphinx-code-cells.css" />
<link rel="stylesheet" type="text/css" href="../_static/styles/furo-extensions.css?digest=30d1aed668e5c3a91c3e3bf6a60b675221979f0e" />
<style>
body {
  --color-code-background: #eeffcc;
  --color-code-foreground: black;
  --color-brand-primary: #76B900;
  --color-brand-content: #76B900;
}
@media not print {
  body[data-theme="dark"] {
    --color-code-background: #272822;
    --color-code-foreground: #f8f8f2;
    --color-brand-primary: #76B900;
    --color-brand-content: #76B900;
  }
  @media (prefers-color-scheme: dark) {
    body:not([data-theme="light"]) {
      --color-code-background: #272822;
      --color-code-foreground: #f8f8f2;
      --color-brand-primary: #76B900;
      --color-brand-content: #76B900;
    }
  }
}
</style></head>
<body>
<script>
document.body.dataset.theme = localStorage.getItem("theme") || "auto";
</script>

<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
<symbol id="svg-toc" viewBox="0 0 24 24">
<title>Contents</title>
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 1024 1024">
<path d="M408 442h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8zm-8 204c0 4.4 3.6 8 8 8h480c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8H408c-4.4 0-8 3.6-8 8v56zm504-486H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zm0 632H120c-4.4 0-8 3.6-8 8v56c0 4.4 3.6 8 8 8h784c4.4 0 8-3.6 8-8v-56c0-4.4-3.6-8-8-8zM115.4 518.9L271.7 642c5.8 4.6 14.4.5 14.4-6.9V388.9c0-7.4-8.5-11.5-14.4-6.9L115.4 505.1a8.74 8.74 0 0 0 0 13.8z"/>
</svg>
</symbol>
<symbol id="svg-menu" viewBox="0 0 24 24">
<title>Menu</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
  stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-menu">
<line x1="3" y1="12" x2="21" y2="12"></line>
<line x1="3" y1="6" x2="21" y2="6"></line>
<line x1="3" y1="18" x2="21" y2="18"></line>
</svg>
</symbol>
<symbol id="svg-arrow-right" viewBox="0 0 24 24">
<title>Expand</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
  stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather-chevron-right">
<polyline points="9 18 15 12 9 6"></polyline>
</svg>
</symbol>
<symbol id="svg-sun" viewBox="0 0 24 24">
<title>Light mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
  stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather-sun">
<circle cx="12" cy="12" r="5"></circle>
<line x1="12" y1="1" x2="12" y2="3"></line>
<line x1="12" y1="21" x2="12" y2="23"></line>
<line x1="4.22" y1="4.22" x2="5.64" y2="5.64"></line>
<line x1="18.36" y1="18.36" x2="19.78" y2="19.78"></line>
<line x1="1" y1="12" x2="3" y2="12"></line>
<line x1="21" y1="12" x2="23" y2="12"></line>
<line x1="4.22" y1="19.78" x2="5.64" y2="18.36"></line>
<line x1="18.36" y1="5.64" x2="19.78" y2="4.22"></line>
</svg>
</symbol>
<symbol id="svg-moon" viewBox="0 0 24 24">
<title>Dark mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
  stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-moon">
<path stroke="none" d="M0 0h24v24H0z" fill="none" />
<path d="M12 3c.132 0 .263 0 .393 0a7.5 7.5 0 0 0 7.92 12.446a9 9 0 1 1 -8.313 -12.454z" />
</svg>
</symbol>
<symbol id="svg-sun-half" viewBox="0 0 24 24">
<title>Auto light/dark mode</title>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor"
  stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="icon-tabler-shadow">
<path stroke="none" d="M0 0h24v24H0z" fill="none"/>
<circle cx="12" cy="12" r="9" />
<path d="M13 12h5" />
<path d="M13 15h4" />
<path d="M13 18h1" />
<path d="M13 9h4" />
<path d="M13 6h1" />
</svg>
</symbol>
</svg>

<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation">
<input type="checkbox" class="sidebar-toggle" name="__toc" id="__toc">
<label class="overlay sidebar-overlay" for="__navigation">
<div class="visually-hidden">Hide navigation sidebar</div>
</label>
<label class="overlay toc-overlay" for="__toc">
<div class="visually-hidden">Hide table of contents sidebar</div>
</label>

<div class="page">
<header class="mobile-header">
<div class="header-left">
<label class="nav-overlay-icon" for="__navigation">
<div class="visually-hidden">Toggle site navigation sidebar</div>
<i class="icon"><svg><use href="#svg-menu"></use></svg></i>
</label>
</div>
<div class="header-center">
<a href="../index.html"><div class="brand">CUTLASS Python</div></a>
</div>
<div class="header-right">
<div class="theme-toggle-container theme-toggle-header">
<button class="theme-toggle">
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
</button>
</div>
<label class="toc-overlay-icon toc-header-icon" for="__toc">
<div class="visually-hidden">Toggle table of contents sidebar</div>
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
</label>
</div>
</header>
<aside class="sidebar-drawer">
<div class="sidebar-container">
<div class="sidebar-sticky"><a class="sidebar-brand" href="../index.html">
|
||
|
||
<div class="sidebar-logo-container">
|
||
<img class="sidebar-logo only-light" src="../_static/cutlass-logo-small.png" alt="Light Logo"/>
|
||
<img class="sidebar-logo only-dark" src="../_static/cutlass-logo-small.png" alt="Dark Logo"/>
|
||
</div>
|
||
|
||
<span class="sidebar-brand-text">CUTLASS Python</span>
|
||
|
||
</a><form class="sidebar-search-container" method="get" action="../search.html" role="search">
|
||
<input class="sidebar-search" placeholder="Search" name="q" aria-label="Search">
|
||
<input type="hidden" name="check_keywords" value="yes">
|
||
<input type="hidden" name="area" value="default">
|
||
</form>
|
||
<div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../index.html">Home</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Getting Started:</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference internal" href="../install.html">Installation</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="00_basic_gemm.html">Getting Started</a></li>
|
||
<li class="toctree-l1"><a class="reference internal" href="../contribute.html">Contributing</a></li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Python Documentation:</span></p>
|
||
<ul>
|
||
<li class="toctree-l1 has-children"><a class="reference internal" href="../modules.html">CUTLASS Python API</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" role="switch" type="checkbox"/><label for="toctree-checkbox-1"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||
<li class="toctree-l2 has-children"><a class="reference internal" href="../cutlass.html">CUTLASS</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" role="switch" type="checkbox"/><label for="toctree-checkbox-2"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul>
|
||
<li class="toctree-l3"><a class="reference internal" href="../cutlass.emit.html">Emitters</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../cutlass.op.html">Operations</a></li>
|
||
<li class="toctree-l3"><a class="reference internal" href="../cutlass.utils.html">Utilities</a></li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Examples and Tutorials:</span></p>
|
||
<ul class="current">
|
||
<li class="toctree-l1 current has-children"><a class="reference internal" href="../examples.html">Examples</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" role="switch" type="checkbox"/><label for="toctree-checkbox-3"><div class="visually-hidden">Toggle child pages in navigation</div><i class="icon"><svg><use href="#svg-arrow-right"></use></svg></i></label><ul class="current">
|
||
<li class="toctree-l2"><a class="reference internal" href="00_basic_gemm.html">Basic GEMM</a></li>
|
||
<li class="toctree-l2"><a class="reference internal" href="01_epilogue.html">Epilogue</a></li>
|
||
<li class="toctree-l2 current current-page"><a class="current reference internal" href="#">PyTorch Extension</a></li>
|
||
</ul>
|
||
</li>
|
||
</ul>
|
||
<p class="caption" role="heading"><span class="caption-text">Reference:</span></p>
|
||
<ul>
|
||
<li class="toctree-l1"><a class="reference external" href="https://github.com/NVIDIA/cutlass">Github</a></li>
|
||
</ul>
|
||
|
||
</div>
|
||
</div>
|
||
|
||
</div>
|
||
|
||
</div>
|
||
</aside>
|
||
<div class="main">
|
||
<div class="content">
|
||
<div class="article-container">
|
||
<a href="#" class="back-to-top muted-link">
|
||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||
<path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8v12z"></path>
|
||
</svg>
|
||
<span>Back to top</span>
|
||
</a>
|
||
<div class="content-icon-container">
|
||
|
||
<div class="theme-toggle-container theme-toggle-content">
|
||
<button class="theme-toggle">
|
||
<div class="visually-hidden">Toggle Light / Dark / Auto color theme</div>
|
||
<svg class="theme-icon-when-auto"><use href="#svg-sun-half"></use></svg>
|
||
<svg class="theme-icon-when-dark"><use href="#svg-moon"></use></svg>
|
||
<svg class="theme-icon-when-light"><use href="#svg-sun"></use></svg>
|
||
</button>
|
||
</div>
|
||
<label class="toc-overlay-icon toc-content-icon" for="__toc">
|
||
<div class="visually-hidden">Toggle table of contents sidebar</div>
|
||
<i class="icon"><svg><use href="#svg-toc"></use></svg></i>
|
||
</label>
|
||
</div>
|
||
<article role="main">
|
||
<section id="Exporting-a-CUTLASS-grouped-GEMM-kernel-to-a-PyTorch-CUDA-extension">
|
||
<h1>Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension<a class="headerlink" href="#Exporting-a-CUTLASS-grouped-GEMM-kernel-to-a-PyTorch-CUDA-extension" title="Permalink to this heading">#</a></h1>
|
||
<p>This notebook walks through a basic example of using the CUTLASS Python interface to declare a grouped GEMM kernel and export it as a PyTorch CUDA extension.</p>
|
||
<p><a class="reference external" href="https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a></p>
|
||
<section id="Background-on-grouped-GEMM">
|
||
<h2>Background on grouped GEMM<a class="headerlink" href="#Background-on-grouped-GEMM" title="Permalink to this heading">#</a></h2>
|
||
<p>Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides) in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM, without the requirement that the sizes and strides of each GEMM be the same.</p>
|
||
<p>For example, if one has <code class="docutils literal notranslate"><span class="pre">p</span></code> GEMMs with sizes:</p>
|
||
<div class="highlight-text notranslate"><div class="highlight"><pre><span></span>M_1 x N_1 x K_1
|
||
M_2 x N_2 x K_2
|
||
...
|
||
M_p x N_p x K_p
|
||
</pre></div>
|
||
</div>
|
||
<p>CUTLASS’s grouped GEMM will execute these in a single CUDA kernel.</p>
|
||
<p>Grouped GEMM is particularly beneficial for saturating the GPU with many small problems that would insufficiently utilize the device in isolation.</p>
|
||
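<p>Semantically, a grouped GEMM is equivalent to looping over the problems and multiplying each pair of operands independently; the grouped kernel simply fuses all of these products into one launch. A minimal sketch of the reference semantics (assuming lists <code class="docutils literal notranslate"><span class="pre">As</span></code> and <code class="docutils literal notranslate"><span class="pre">Bs</span></code> of operand tensors, whose per-problem sizes may all differ):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Reference semantics of a grouped GEMM: one independent product per problem.
# A grouped kernel computes the same set of results in a single kernel launch.
Ds = [a @ b for a, b in zip(As, Bs)]
</pre></div>
</div>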
</section>
<section id="Declaring-a-grouped-GEMM-via-the-CUTLASS-Python-interface">
<h2>Declaring a grouped GEMM via the CUTLASS Python interface<a class="headerlink" href="#Declaring-a-grouped-GEMM-via-the-CUTLASS-Python-interface" title="Permalink to this heading">#</a></h2>
<p>A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one simply calls <code class="docutils literal notranslate"><span class="pre">cutlass.op.GroupedGemm</span></code>.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[1]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">cutlass</span>
<span class="kn">import</span> <span class="nn">torch</span>

<span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span>
<span class="n">plan</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">op</span><span class="o">.</span><span class="n">GroupedGemm</span><span class="p">(</span><span class="n">element</span><span class="o">=</span><span class="n">dtype</span><span class="p">,</span> <span class="n">layout</span><span class="o">=</span><span class="n">cutlass</span><span class="o">.</span><span class="n">LayoutType</span><span class="o">.</span><span class="n">RowMajor</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area stderr docutils container">
<div class="highlight"><pre>
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
</pre></div></div>
</div>
<p>We can then compile and run this operation on a group of GEMMs. We’ll first set up some utility functions to initialize GEMMs.</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[2]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">random</span>
<span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">2023</span><span class="p">)</span>

<span class="c1"># Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K</span>
<span class="k">def</span> <span class="nf">initialize</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">):</span>
    <span class="n">sizes</span> <span class="o">=</span> <span class="p">[(</span><span class="n">M</span><span class="p">,</span> <span class="n">K</span><span class="p">),</span> <span class="p">(</span><span class="n">K</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="p">(</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">)]</span>
    <span class="k">return</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">)</span> <span class="k">for</span> <span class="n">size</span> <span class="ow">in</span> <span class="n">sizes</span><span class="p">]</span>

<span class="c1"># Utility function to generate `problems` GEMMs of random sizes</span>
<span class="k">def</span> <span class="nf">generate_problems</span><span class="p">(</span><span class="n">problems</span><span class="p">):</span>
    <span class="n">valid_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">1024</span><span class="p">]</span>
    <span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">problems</span><span class="p">):</span>
        <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span> <span class="o">=</span> <span class="p">[</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">valid_sizes</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">)]</span>
        <span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">initialize</span><span class="p">(</span><span class="n">dtype</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">K</span><span class="p">)</span>
        <span class="n">As</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">A</span><span class="p">)</span>
        <span class="n">Bs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">B</span><span class="p">)</span>
        <span class="n">Cs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">C</span><span class="p">)</span>
        <span class="n">Ds</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">D</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span>
</pre></div>
</div>
</div>
<p>We’ll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch.</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[3]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span><span class="p">,</span> <span class="o">=</span> <span class="n">generate_problems</span><span class="p">(</span><span class="mi">50</span><span class="p">)</span>

<span class="n">plan</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">,</span> <span class="n">Cs</span><span class="p">,</span> <span class="n">Ds</span><span class="p">,</span> <span class="n">print_module</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>

<span class="k">for</span> <span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">Ds</span><span class="p">,</span> <span class="n">Ds_torch</span><span class="p">):</span>
    <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span><span class="p">)</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>

// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<256, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAdd
>::GemmKernel;

// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_type :
  public cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_base { };

</pre></div></div>
</div>
</section>
<section id="Exporting-the-CUTLASS-kernel-to-a-PyTorch-CUDA-extension">
<h2>Exporting the CUTLASS kernel to a PyTorch CUDA extension<a class="headerlink" href="#Exporting-the-CUTLASS-kernel-to-a-PyTorch-CUDA-extension" title="Permalink to this heading">#</a></h2>
<p>The procedure above allows one to quickly experiment with CUTLASS kernels. However, one might prefer to use the CUTLASS kernel via a <a class="reference external" href="https://pytorch.org/tutorials/advanced/cpp_extension.html">PyTorch CUDA extension</a>. This avoids the runtime overhead associated with the Python portions of the CUTLASS Python interface.</p>
<p>The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later “ahead-of-time” compilation, or be just-in-time compiled and returned to the user.</p>
<p>To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[4]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">op</span> <span class="o">=</span> <span class="n">plan</span><span class="o">.</span><span class="n">construct</span><span class="p">()</span>
<span class="n">grouped_gemm</span> <span class="o">=</span> <span class="n">cutlass</span><span class="o">.</span><span class="n">emit</span><span class="o">.</span><span class="n">pytorch</span><span class="p">(</span><span class="n">op</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'grouped_gemm'</span><span class="p">,</span> <span class="n">cc</span><span class="o">=</span><span class="n">plan</span><span class="o">.</span><span class="n">cc</span><span class="p">,</span> <span class="n">sourcedir</span><span class="o">=</span><span class="s1">'out'</span><span class="p">,</span> <span class="n">jit</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
</div>
<p>The <code class="docutils literal notranslate"><span class="pre">cutlass.emit.pytorch</span></code> function emits:</p>
<ul class="simple">
<li><p><code class="docutils literal notranslate"><span class="pre">out/grouped_gemm_kernel.cu</span></code>: contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">out/grouped_gemm.cpp</span></code>: contains a C++ wrapper around the aforementioned CUTLASS kernel</p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">setup.py</span></code>: contains the <code class="docutils literal notranslate"><span class="pre">setuptools</span></code> script for building and installing the generated extension</p></li>
</ul>
<p>The extension can be built from within the <code class="docutils literal notranslate"><span class="pre">out</span></code> directory by running:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span><span class="nv">TORCH_CUDA_ARCH_LIST</span><span class="o">=</span><span class="s2">"8.0"</span><span class="w"> </span>python<span class="w"> </span>setup.py<span class="w"> </span>install
</pre></div>
</div>
<p>where <code class="docutils literal notranslate"><span class="pre">TORCH_CUDA_ARCH_LIST</span></code> is set to the compute capability of the device on which the kernel will be run (<code class="docutils literal notranslate"><span class="pre">"8.0"</span></code> here assumes that one is building for SM80).</p>
<p>See the PyTorch <a class="reference external" href="https://pytorch.org/tutorials/advanced/cpp_extension.html">“Custom C++ and CUDA Extensions”</a> tutorial for more details on this.</p>
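<p>For reference, a minimal sketch of what such a <code class="docutils literal notranslate"><span class="pre">setup.py</span></code> might look like is shown below. This illustrates the standard <code class="docutils literal notranslate"><span class="pre">torch.utils.cpp_extension</span></code> build pattern, not the exact file emitted by <code class="docutils literal notranslate"><span class="pre">cutlass.emit.pytorch</span></code>; the module and source names follow the <code class="docutils literal notranslate"><span class="pre">name</span></code> argument used above, and the generated script may add compiler flags.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Sketch of a PyTorch CUDA extension build script (assumed shape; the
# generated setup.py may differ in names and compiler flags).
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='grouped_gemm',
    ext_modules=[
        CUDAExtension(
            name='grouped_gemm',
            # Sources emitted by cutlass.emit.pytorch (see the list above)
            sources=['grouped_gemm.cpp', 'grouped_gemm_kernel.cu'],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)
</pre></div>
</div>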
<p>One could then use the kernel in a later PyTorch module by running:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
|
||
<span class="kn">import</span> <span class="nn">grouped_gemm</span>
|
||
|
||
<span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
|
||
</pre></div>
|
||
</div>
|
||
<p>In this case, however, we set <code class="docutils literal notranslate"><span class="pre">jit=True</span></code>, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly. Under the hood, this leverages the <a class="reference external" href="https://pytorch.org/tutorials/advanced/cpp_extension.html">torch.utils.cpp_extension.load</a> method and returns back the loaded extension.</p>
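<p>Roughly speaking, the JIT path is equivalent to calling <code class="docutils literal notranslate"><span class="pre">torch.utils.cpp_extension.load</span></code> on the emitted sources directly. A sketch under that assumption (the exact arguments the emitter passes may differ):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Sketch: JIT-compile and load the emitted sources by hand. This is assumed
# to be roughly what cutlass.emit.pytorch does when jit=True.
from torch.utils.cpp_extension import load

grouped_gemm = load(
    name='grouped_gemm',
    sources=['out/grouped_gemm.cpp', 'out/grouped_gemm_kernel.cu'],
)
</pre></div>
</div>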
<p>We can then use the extension and compare its results against those of vanilla PyTorch GEMMs:</p>
<div class="nbinput nblast docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[5]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">Ds</span> <span class="o">=</span> <span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
<span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">Ds</span><span class="p">,</span> <span class="n">Ds_torch</span><span class="p">):</span>
    <span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">d</span><span class="p">,</span> <span class="n">d_torch</span><span class="p">)</span>
</pre></div>
</div>
</div>
<p>Finally, we can profile our grouped GEMM extension:</p>
<div class="nbinput docutils container">
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[6]:
</pre></div>
</div>
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">num_warmup</span> <span class="o">=</span> <span class="mi">20</span>
<span class="n">num_profile</span> <span class="o">=</span> <span class="mi">100</span>

<span class="c1"># Warmup iterations</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_warmup</span><span class="p">):</span>
    <span class="n">Ds</span> <span class="o">=</span> <span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
    <span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>

<span class="c1"># Timing iterations</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="n">grouped</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">nongrouped</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_profile</span><span class="p">):</span>
    <span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    <span class="n">Ds</span> <span class="o">=</span> <span class="n">grouped_gemm</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)</span>
    <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
    <span class="n">grouped</span> <span class="o">+=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>

    <span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    <span class="n">Ds_torch</span> <span class="o">=</span> <span class="p">[</span><span class="n">a</span> <span class="o">@</span> <span class="n">b</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">As</span><span class="p">,</span> <span class="n">Bs</span><span class="p">)]</span>
    <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
    <span class="n">nongrouped</span> <span class="o">+=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>

<span class="nb">print</span><span class="p">(</span><span class="s1">'Grouped: </span><span class="si">{:.3f}</span><span class="s1"> us'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">grouped</span> <span class="o">*</span> <span class="mf">1e6</span><span class="o">/</span><span class="n">num_profile</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">'Non-Grouped: </span><span class="si">{:.3f}</span><span class="s1"> us'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">nongrouped</span> <span class="o">*</span> <span class="mf">1e6</span><span class="o">/</span><span class="n">num_profile</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">'Speedup: </span><span class="si">{:.3f}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">nongrouped</span> <span class="o">/</span> <span class="n">grouped</span><span class="p">))</span>
</pre></div>
</div>
</div>
<div class="nboutput nblast docutils container">
<div class="prompt empty docutils container">
</div>
<div class="output_area docutils container">
<div class="highlight"><pre>
Grouped: 400.696 us
Non-Grouped: 646.670 us
Speedup: 1.614
</pre></div></div>
</div>
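<p>Note that this measures host wall-clock time around each call, including Python overhead. For device-side timing, CUDA events can be used instead; a minimal sketch (assuming the same <code class="docutils literal notranslate"><span class="pre">grouped_gemm</span></code> module and inputs as above):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span># Sketch: time the grouped kernel with CUDA events, which measure elapsed
# GPU time between the recorded points rather than host wall-clock time.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
Ds = grouped_gemm.run(As, Bs)
end.record()
torch.cuda.synchronize()

# elapsed_time returns milliseconds; convert to microseconds.
print('Grouped: {:.3f} us'.format(start.elapsed_time(end) * 1e3))
</pre></div>
</div>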
<div class="nbinput nblast docutils container">
|
||
<div class="prompt highlight-none notranslate"><div class="highlight"><pre><span></span>[ ]:
|
||
</pre></div>
|
||
</div>
|
||
<div class="input_area highlight-ipython3 notranslate"><div class="highlight"><pre><span></span>
|
||
</pre></div>
|
||
</div>
|
||
</div>
|
||
</section>
|
||
</section>
|
||
|
||
</article>
</div>
<footer>
<div class="related-pages">
<a class="prev-page" href="01_epilogue.html">
<svg class="furo-related-icon"><use href="#svg-arrow-right"></use></svg>
<div class="page-info">
<div class="context">
<span>Previous</span>
</div>
<div class="title">Example of using elementwise activation functions in the CUTLASS Python interface</div>
</div>
</a>
</div>
<div class="bottom-of-page">
<div class="left-details">
<div class="copyright">
Copyright © 2023, NVIDIA
</div>
Made with <a href="https://www.sphinx-doc.org/">Sphinx</a> and <a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>'s
<a href="https://github.com/pradyunsg/furo">Furo</a>
</div>
<div class="right-details">
<div class="icons">
<a class="muted-link " href="https://github.com/NVIDIA/cutlass" aria-label="GitHub">
<svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
<path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
</svg>
</a>
</div>
</div>
</div>
</footer>
</div>
<aside class="toc-drawer">
<div class="toc-sticky toc-scroll">
<div class="toc-title-container">
<span class="toc-title">
On this page
</span>
</div>
<div class="toc-tree-container">
<div class="toc-tree">
<ul>
<li><a class="reference internal" href="#">Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension</a><ul>
<li><a class="reference internal" href="#Background-on-grouped-GEMM">Background on grouped GEMM</a></li>
<li><a class="reference internal" href="#Declaring-a-grouped-GEMM-via-the-CUTLASS-Python-interface">Declaring a grouped GEMM via the CUTLASS Python interface</a></li>
<li><a class="reference internal" href="#Exporting-the-CUTLASS-kernel-to-a-PyTorch-CUDA-extension">Exporting the CUTLASS kernel to a PyTorch CUDA extension</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
</aside>
</div>
</div><script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/sphinx_highlight.js"></script>
<script src="../_static/scripts/furo.js"></script>
<script src="../_static/clipboard.min.js"></script>
<script src="../_static/copybutton.js"></script>
<script src="../_static/tabs.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script>window.MathJax = {"tex": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true}, "options": {"ignoreHtmlClass": "tex2jax_ignore|mathjax_ignore|document", "processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
<script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
</body>
</html>