# cutlass/python/cutlass_library/library.py

#################################################################################################
#
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Data types and tags used for emitting CUTLASS C++ kernels
"""
import enum
import re
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
# as the default 3.5.2 on Ubuntu 16.04.
#
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
try:
  from enum import auto as enum_auto
except ImportError:
  __cutlass_library_auto_enum = 0
  def enum_auto() -> int:
    global __cutlass_library_auto_enum
    i = __cutlass_library_auto_enum
    __cutlass_library_auto_enum += 1
    return i
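# Illustrative example (a minimal sketch of the ImportError fallback above, not exercised
# in this module): each call hands out the next integer from a shared counter, which is
# all the enum member definitions below require.
#
#   enum_auto()   # -> 0
#   enum_auto()   # -> 1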
###################################################################################################
#
class GeneratorTarget(enum.Enum):
  Library = enum_auto()
#
GeneratorTargetNames = {
GeneratorTarget.Library: 'library'
}
#
###################################################################################################
#
class DataType(enum.Enum):
  void = enum_auto()  # primarily used to disable C tensor for epilogues
  b1 = enum_auto()
  u4 = enum_auto()
  u8 = enum_auto()
  u16 = enum_auto()
  u32 = enum_auto()
  u64 = enum_auto()
  s4 = enum_auto()
  s8 = enum_auto()
  s16 = enum_auto()
  s32 = enum_auto()
  s64 = enum_auto()
  e4m3 = enum_auto()
  e5m2 = enum_auto()
  f16 = enum_auto()
  bf16 = enum_auto()
  f32 = enum_auto()
  tf32 = enum_auto()
  f64 = enum_auto()
  cf16 = enum_auto()
  cbf16 = enum_auto()
  cf32 = enum_auto()
  ctf32 = enum_auto()
  cf64 = enum_auto()
  cs4 = enum_auto()
  cs8 = enum_auto()
  cs16 = enum_auto()
  cs32 = enum_auto()
  cs64 = enum_auto()
  cu4 = enum_auto()
  cu8 = enum_auto()
  cu16 = enum_auto()
  cu32 = enum_auto()
  cu64 = enum_auto()
  invalid = enum_auto()
#
ShortDataTypeNames = {
DataType.s32: 'i',
DataType.e4m3: 'e4m3',
DataType.e5m2: 'e5m2',
DataType.f16: 'h',
DataType.f32: 's',
DataType.f64: 'd',
DataType.cf32: 'c',
DataType.cf64: 'z',
}
#
DataTypeNames = {
DataType.void: "void",
DataType.b1: "b1",
DataType.u4: "u4",
DataType.u8: "u8",
DataType.u16: "u16",
DataType.u32: "u32",
DataType.u64: "u64",
DataType.s4: "s4",
DataType.s8: "s8",
DataType.s16: "s16",
DataType.s32: "s32",
DataType.s64: "s64",
DataType.e4m3: 'e4m3',
DataType.e5m2: 'e5m2',
DataType.f16: "f16",
DataType.bf16: "bf16",
DataType.f32: "f32",
DataType.tf32: "tf32",
DataType.f64: "f64",
DataType.cf16: "cf16",
DataType.cbf16: "cbf16",
DataType.cf32: "cf32",
DataType.ctf32: "ctf32",
DataType.cf64: "cf64",
DataType.cu4: "cu4",
DataType.cu8: "cu8",
DataType.cu16: "cu16",
DataType.cu32: "cu32",
DataType.cu64: "cu64",
DataType.cs4: "cs4",
DataType.cs8: "cs8",
DataType.cs16: "cs16",
DataType.cs32: "cs32",
DataType.cs64: "cs64",
}
DataTypeTag = {
DataType.void: "void",
DataType.b1: "cutlass::uint1b_t",
DataType.u4: "cutlass::uint4b_t",
DataType.u8: "uint8_t",
DataType.u16: "uint16_t",
DataType.u32: "uint32_t",
DataType.u64: "uint64_t",
DataType.s4: "cutlass::int4b_t",
DataType.s8: "int8_t",
DataType.s16: "int16_t",
DataType.s32: "int32_t",
DataType.s64: "int64_t",
DataType.e4m3: 'cutlass::float_e4m3_t',
DataType.e5m2: 'cutlass::float_e5m2_t',
DataType.f16: "cutlass::half_t",
DataType.bf16: "cutlass::bfloat16_t",
DataType.f32: "float",
DataType.tf32: "cutlass::tfloat32_t",
DataType.f64: "double",
DataType.cf16: "cutlass::complex<cutlass::half_t>",
DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
DataType.cf32: "cutlass::complex<float>",
DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
DataType.cf64: "cutlass::complex<double>",
DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
DataType.cs8: "cutlass::complex<cutlass::int8_t>",
DataType.cs16: "cutlass::complex<cutlass::int16_t>",
DataType.cs32: "cutlass::complex<cutlass::int32_t>",
DataType.cs64: "cutlass::complex<cutlass::int64_t>",
}
DataTypeSize = {
DataType.void: 0,
DataType.b1: 1,
DataType.u4: 4,
DataType.u8: 8,
DataType.u16: 16,
DataType.u32: 32,
DataType.u64: 64,
DataType.s4: 4,
DataType.s8: 8,
DataType.s16: 16,
DataType.s32: 32,
DataType.s64: 64,
DataType.e4m3: 8,
DataType.e5m2: 8,
DataType.f16: 16,
DataType.bf16: 16,
DataType.f32: 32,
DataType.tf32: 32,
DataType.f64: 64,
DataType.cf16: 32,
DataType.cbf16: 32,
DataType.cf32: 64,
DataType.ctf32: 32,
DataType.cf64: 128,
DataType.cu4: 8,
DataType.cu8: 16,
DataType.cu16: 32,
DataType.cu32: 64,
DataType.cu64: 128,
DataType.cs4: 8,
DataType.cs8: 16,
DataType.cs16: 32,
DataType.cs32: 64,
DataType.cs64: 128,
}
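# Illustrative example (not exercised in this module): sizes are expressed in bits so that
# sub-byte types such as s4 and b1 can be represented exactly; converting an element count
# to bytes therefore divides by 8, as CalculateSmemUsage() does further below.
#
#   DataTypeSize[DataType.s4] * 128 * 64 // 8   # -> 4096 bytes for a 128x64 s4 tile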
###################################################################################################
#
class BlasMode(enum.Enum):
  symmetric = enum_auto()
  hermitian = enum_auto()
#
BlasModeTag = {
BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric',
BlasMode.hermitian: 'cutlass::BlasMode::kHermitian',
}
#
class ComplexTransform(enum.Enum):
  none = enum_auto()
  conj = enum_auto()
#
ComplexTransformTag = {
ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
}
# Used for cutlass3x complex kernel collective mainloop builder instantiation
ComplexTransformTag3x = {
ComplexTransform.none: 'cute::identity',
ComplexTransform.conj: 'cute::conjugate',
}
#
RealComplexBijection = [
(DataType.f16, DataType.cf16),
(DataType.f32, DataType.cf32),
(DataType.f64, DataType.cf64),
]
#
def is_complex(data_type):
  for r, c in RealComplexBijection:
    if data_type == c:
      return True
  return False
#
def get_complex_from_real(real_type):
  for r, c in RealComplexBijection:
    if real_type == r:
      return c
  return DataType.invalid
#
def get_real_from_complex(complex_type):
  for r, c in RealComplexBijection:
    if complex_type == c:
      return r
  return DataType.invalid
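# Illustrative examples (not exercised in this module): the helpers above walk
# RealComplexBijection and return DataType.invalid when no real/complex pairing exists.
#
#   get_complex_from_real(DataType.f32)    # -> DataType.cf32
#   get_real_from_complex(DataType.cf16)   # -> DataType.f16
#   is_complex(DataType.cf32)              # -> True
#   get_complex_from_real(DataType.s8)     # -> DataType.invalid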
#
class ComplexMultiplyOp(enum.Enum):
  multiply_add = enum_auto()
  gaussian = enum_auto()
###################################################################################################
#
class MathOperation(enum.Enum):
  multiply_add = enum_auto()
  multiply_add_saturate = enum_auto()
  multiply_add_mixed_input_upcast = enum_auto()
  xor_popc = enum_auto()
  and_popc = enum_auto()
  multiply_add_fast_bf16 = enum_auto()
  multiply_add_fast_f16 = enum_auto()
  multiply_add_fast_f32 = enum_auto()
  multiply_add_complex_fast_f32 = enum_auto()
  multiply_add_complex = enum_auto()
  multiply_add_complex_gaussian = enum_auto()
  multiply_add_fast_accum = enum_auto()
#
MathOperationTag = {
MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast',
MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
MathOperation.and_popc: 'cutlass::arch::OpAndPopc',
MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16',
MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32',
MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32',
MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex',
MathOperation.multiply_add_fast_accum: 'cutlass::arch::OpMultiplyAddFastAccum',
}
###################################################################################################
#
class LayoutType(enum.Enum):
  ColumnMajor = enum_auto()
  RowMajor = enum_auto()
  ColumnMajorInterleaved2 = enum_auto()
  RowMajorInterleaved2 = enum_auto()
  ColumnMajorInterleaved32 = enum_auto()
  RowMajorInterleaved32 = enum_auto()
  ColumnMajorInterleaved64 = enum_auto()
  RowMajorInterleaved64 = enum_auto()
  TensorNWC = enum_auto()
  TensorNHWC = enum_auto()
  TensorNDHWC = enum_auto()
  TensorNCHW = enum_auto()
  TensorNGHWC = enum_auto()
  TensorNC32HW32 = enum_auto()
  TensorNC64HW64 = enum_auto()
  TensorC32RSK32 = enum_auto()
  TensorC64RSK64 = enum_auto()
  TensorKCS = enum_auto()
  TensorKCSR = enum_auto()
  TensorKCSRT = enum_auto()
#
LayoutTag = {
LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
LayoutType.RowMajor: 'cutlass::layout::RowMajor',
LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
LayoutType.TensorNWC: 'cutlass::layout::TensorNWC',
LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
LayoutType.TensorKCS: 'cutlass::layout::TensorKCS',
LayoutType.TensorKCSR: 'cutlass::layout::TensorKCSR',
LayoutType.TensorKCSRT: 'cutlass::layout::TensorKCSRT'
}
#
TransposedLayout = {
LayoutType.ColumnMajor: LayoutType.RowMajor,
LayoutType.RowMajor: LayoutType.ColumnMajor,
LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
LayoutType.TensorNHWC: LayoutType.TensorNHWC
}
#
ShortLayoutTypeNames = {
LayoutType.ColumnMajor: 'n',
LayoutType.ColumnMajorInterleaved2: 'n2',
LayoutType.ColumnMajorInterleaved32: 'n32',
LayoutType.ColumnMajorInterleaved64: 'n64',
LayoutType.RowMajor: 't',
LayoutType.RowMajorInterleaved2: 't2',
LayoutType.RowMajorInterleaved32: 't32',
LayoutType.RowMajorInterleaved64: 't64',
LayoutType.TensorNWC: 'nwc',
LayoutType.TensorNHWC: 'nhwc',
LayoutType.TensorNDHWC: 'ndhwc',
LayoutType.TensorNCHW: 'nchw',
LayoutType.TensorNGHWC: 'nghwc',
LayoutType.TensorNC32HW32: 'nc32hw32',
LayoutType.TensorNC64HW64: 'nc64hw64',
LayoutType.TensorC32RSK32: 'c32rsk32',
LayoutType.TensorC64RSK64: 'c64rsk64',
LayoutType.TensorKCS: 'kcs',
LayoutType.TensorKCSR: 'kcsr',
LayoutType.TensorKCSRT: 'kcsrt'
}
#
ShortComplexLayoutNames = {
(LayoutType.ColumnMajor, ComplexTransform.none): 'n',
(LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
(LayoutType.RowMajor, ComplexTransform.none): 't',
(LayoutType.RowMajor, ComplexTransform.conj): 'h'
}
###################################################################################################
class KernelScheduleType(enum.Enum):
  ScheduleAuto = enum_auto()
  Multistage = enum_auto()
  CpAsyncWarpSpecialized = enum_auto()
  CpAsyncWarpSpecializedPingpong = enum_auto()
  CpAsyncWarpSpecializedCooperative = enum_auto()
  Tma = enum_auto()
  TmaWarpSpecialized = enum_auto()
  TmaWarpSpecializedPingpong = enum_auto()
  TmaWarpSpecializedCooperative = enum_auto()
  TmaWarpSpecializedFP8FastAccum = enum_auto()
  TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
  TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
  ImplicitTmaWarpSpecializedSm90 = enum_auto()
#
KernelScheduleTag = {
KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized',
KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong',
KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative',
KernelScheduleType.Tma: 'cutlass::gemm::KernelTma',
KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized',
KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong',
KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative',
KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum',
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum',
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
}
#
KernelScheduleSuffixes = {
KernelScheduleType.ScheduleAuto: '',
KernelScheduleType.Multistage: '_cpasync',
KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized',
KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong',
KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative',
KernelScheduleType.Tma: '_unspecialized',
KernelScheduleType.TmaWarpSpecialized: '_warpspecialized',
KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum',
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized',
}
class EpilogueScheduleType(enum.Enum):
  ScheduleAuto = enum_auto()
  EpilogueTransposed = enum_auto()
  NoSmemWarpSpecialized = enum_auto()
  TmaWarpSpecialized = enum_auto()
  TmaWarpSpecializedCooperative = enum_auto()
#
EpilogueScheduleTag = {
EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto',
EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed',
EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
}
#
EpilogueScheduleSuffixes = {
EpilogueScheduleType.ScheduleAuto: '',
EpilogueScheduleType.EpilogueTransposed: '',
EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
}
class EpilogueFunctor3x(enum.Enum):
  LinearCombination = enum_auto()
#
EpilogueFunctor3xTag = {
EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
}
class TileSchedulerType(enum.Enum):
  Default = enum_auto()
  Persistent = enum_auto()
  StreamK = enum_auto()
#
TileSchedulerTag = {
TileSchedulerType.Default: 'void',
TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler',
TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler',
}
#
TileSchedulerSuffixes = {
TileSchedulerType.Default: '',
TileSchedulerType.Persistent: '',
TileSchedulerType.StreamK: '_stream_k',
}
###################################################################################################
#
class SideMode(enum.Enum):
  Left = enum_auto()
  Right = enum_auto()
#
SideModeTag = {
SideMode.Left: 'cutlass::SideMode::kLeft',
SideMode.Right: 'cutlass::SideMode::kRight'
}
#
ShortSideModeNames = {
SideMode.Left: 'ls',
SideMode.Right: 'rs'
}
###################################################################################################
#
class FillMode(enum.Enum):
  Lower = enum_auto()
  Upper = enum_auto()
#
FillModeTag = {
FillMode.Lower: 'cutlass::FillMode::kLower',
FillMode.Upper: 'cutlass::FillMode::kUpper'
}
#
ShortFillModeNames = {
FillMode.Lower: 'l',
FillMode.Upper: 'u'
}
###################################################################################################
#
class DiagType(enum.Enum):
  NonUnit = enum_auto()
  Unit = enum_auto()
#
DiagTypeTag = {
DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
DiagType.Unit: 'cutlass::DiagType::kUnit'
}
#
ShortDiagTypeNames = {
DiagType.NonUnit: 'nu',
DiagType.Unit: 'un'
}
###################################################################################################
#
class OpcodeClass(enum.Enum):
  Simt = enum_auto()
  TensorOp = enum_auto()
  WmmaTensorOp = enum_auto()
  SparseTensorOp = enum_auto()

OpcodeClassNames = {
OpcodeClass.Simt: 'simt',
OpcodeClass.TensorOp: 'tensorop',
OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
}
OpcodeClassTag = {
OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
}
###################################################################################################
#
class OperationKind(enum.Enum):
  Gemm = enum_auto()
  RankK = enum_auto()
  Rank2K = enum_auto()
  Trmm = enum_auto()
  Symm = enum_auto()
  Conv2d = enum_auto()
  Conv3d = enum_auto()
#
OperationKindNames = {
OperationKind.Gemm: 'gemm'
, OperationKind.RankK: 'rank_k'
, OperationKind.Rank2K: 'rank_2k'
, OperationKind.Trmm: 'trmm'
, OperationKind.Symm: 'symm'
, OperationKind.Conv2d: 'conv2d'
, OperationKind.Conv3d: 'conv3d'
}
#
class Target(enum.Enum):
  library = enum_auto()
#
ArchitectureNames = {
50: 'maxwell',
60: 'pascal',
61: 'pascal',
70: 'volta',
75: 'turing',
80: 'ampere',
89: 'ada',
90: 'hopper'
}
#
SharedMemPerCC = {
70: 96, # 96KB of SMEM
72: 96, # 96KB of SMEM
75: 64, # 64KB of SMEM
80: 163, # 163KB of SMEM - 1KB reserved for the driver
86: 99, # 99KB of SMEM - 1KB reserved for the driver
87: 163, # 163KB of SMEM - 1KB reserved for the driver
89: 99, # 99KB of SMEM - 1KB reserved for the driver
90: 227, # 227KB of SMEM - 1KB reserved for the driver
}
###################################################################################################
#
def SubstituteTemplate(template, values):
  text = template
  changed = True
  while changed:
    changed = False
    for key, value in values.items():
      regex = "\\$\\{%s\\}" % key
      newtext = re.sub(regex, value, text)
      if newtext != text:
        changed = True
        text = newtext
  return text
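# Illustrative example (not exercised in this module): SubstituteTemplate() keeps replacing
# ${key} placeholders until the text stops changing, so a substituted value may itself
# contain placeholders that are resolved on a later pass.
#
#   SubstituteTemplate("${outer}", {'outer': '${inner}x4', 'inner': '128'})   # -> '128x4'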
###################################################################################################
#
class GemmKind(enum.Enum):
  Gemm = enum_auto()
  Sparse = enum_auto()
  Universal = enum_auto()
  Universal3x = enum_auto()
  SparseUniversal3x = enum_auto()
  PlanarComplex = enum_auto()
  PlanarComplexArray = enum_auto()
  Grouped = enum_auto()
#
GemmKindNames = {
GemmKind.Gemm: "gemm",
GemmKind.Sparse: "spgemm",
GemmKind.Universal: "gemm",
GemmKind.Universal3x: "gemm",
GemmKind.SparseUniversal3x: "spgemm",
GemmKind.PlanarComplex: "gemm_planar_complex",
GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
GemmKind.Grouped: "gemm_grouped",
}
#
class RankKKind(enum.Enum):
  Universal = enum_auto()
#
RankKKindNames = {
RankKKind.Universal: "rank_k"
}
#
class TrmmKind(enum.Enum):
  Universal = enum_auto()
#
TrmmKindNames = {
TrmmKind.Universal: "trmm"
}
#
class SymmKind(enum.Enum):
  Universal = enum_auto()
#
SymmKindNames = {
SymmKind.Universal: "symm"
}
#
class EpilogueFunctor(enum.Enum):
  LinearCombination = enum_auto()
  LinearCombinationClamp = enum_auto()
#
EpilogueFunctorTag = {
EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
}
#
class SwizzlingFunctor(enum.Enum):
  Identity1 = enum_auto()
  Identity2 = enum_auto()
  Identity4 = enum_auto()
  Identity8 = enum_auto()
  Horizontal = enum_auto()
  StridedDgradIdentity1 = enum_auto()
  StridedDgradIdentity4 = enum_auto()
  StridedDgradHorizontal = enum_auto()
  StreamK = enum_auto()
#
SwizzlingFunctorTag = {
SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle',
SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
}
#
class GroupScheduleMode(enum.Enum):
  Device = enum_auto()
  Host = enum_auto()
#
GroupScheduleModeTag = {
GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly',
GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute'
}
#
ShortGroupScheduleModeNames = {
GroupScheduleMode.Device: 'Device',
GroupScheduleMode.Host: 'Host'
}
###################################################################################################
#
class ConvKind(enum.IntEnum):
  Fprop = 0
  Dgrad = 1
  Wgrad = 2
#
ConvKindTag = {
ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad'
}
ConvKindNames = {
ConvKind.Fprop: 'fprop',
ConvKind.Dgrad: 'dgrad',
ConvKind.Wgrad: 'wgrad',
}
class ConvMode(enum.IntEnum):
  CrossCorrelation = 0
  Convolution = 1
#
class IteratorAlgorithm(enum.Enum):
  Analytic = 0
  Optimized = 1
  FixedChannels = 2
  FewChannels = 3
  FixedStrideDilation = 4
#
IteratorAlgorithmTag = {
IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels',
IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation'
}
IteratorAlgorithmNames = {
IteratorAlgorithm.Analytic: 'analytic',
IteratorAlgorithm.Optimized: 'optimized',
IteratorAlgorithm.FixedChannels: 'fixed_channels',
IteratorAlgorithm.FewChannels: 'few_channels',
IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation'
}
#
class StrideSupport(enum.Enum):
  Strided = 0
  Unity = 1
  Fixed = 2
#
StrideSupportTag = {
StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed'
}
StrideSupportNames = {
StrideSupport.Strided: '',
StrideSupport.Unity: 'unity_stride',
StrideSupport.Fixed: 'fixed_stride'
}
#
class GroupMode(enum.Enum):
  NoneGroup = enum_auto()       # dense convolution (G = 1)
  SingleGroup = enum_auto()     # grouped convolution (single group per CTA)
  MultipleGroup = enum_auto()   # grouped convolution (multiple groups per CTA)
  Depthwise = enum_auto()       # depthwise convolution (C = K = G)
#
GroupModeTag = {
GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
}
GroupModeNames = {
GroupMode.NoneGroup: '',
GroupMode.SingleGroup: 'single_group',
GroupMode.MultipleGroup: 'multiple_group',
GroupMode.Depthwise: 'depthwise',
}
###################################################################################################
#
class MathInstruction:
  def __init__(self,
               instruction_shape,
               element_a, element_b, element_accumulator,
               opcode_class, math_operation = MathOperation.multiply_add):
    self.instruction_shape = instruction_shape
    self.element_a = element_a
    self.element_b = element_b
    self.element_accumulator = element_accumulator
    self.opcode_class = opcode_class
    self.math_operation = math_operation
#
class TileDescription:
  def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1, 1, 1]):
    self.threadblock_shape = threadblock_shape
    self.tile_shape = threadblock_shape
    self.stages = stages
    self.warp_count = warp_count
    self.math_instruction = math_instruction
    self.minimum_compute_capability = min_compute
    self.maximum_compute_capability = max_compute
    self.cluster_shape = cluster_shape

  def procedural_name(self):
    if self.minimum_compute_capability >= 90:
      return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format(
        tbm = self.threadblock_shape[0],
        tbn = self.threadblock_shape[1],
        tbk = self.threadblock_shape[2],
        cm = self.cluster_shape[0],
        cn = self.cluster_shape[1],
        ck = self.cluster_shape[2],
        s = self.stages)
    else:
      return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
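# Illustrative examples (not exercised in this module; math_inst stands in for a
# MathInstruction instance and the compute-capability bounds are placeholders):
#
#   TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, 80, 89).procedural_name()
#   # -> '256x128_64x3'
#
#   TileDescription([128, 128, 64], 0, [4, 1, 1], math_inst, 90, 90,
#                   cluster_shape = [2, 1, 1]).procedural_name()
#   # -> '128x128x64_2x1x1_0'  (SM90 names include the cluster shape)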
#
class Direct2dConvFixedStrideDilationTileDescription:
  def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
    self.threadblock_shape = [threadblock_output_shape[0] * threadblock_output_shape[1] * threadblock_output_shape[2],
                              threadblock_output_shape[3],
                              filter_shape[0] * filter_shape[1]]
    self.threadblock_output_shape = threadblock_output_shape
    self.filter_shape = filter_shape
    self.stages = stages
    self.warp_count = warp_count
    self.stride = stride
    self.dilation = dilation
    self.math_instruction = math_instruction
    self.minimum_compute_capability = min_compute
    self.maximum_compute_capability = max_compute

  def procedural_name(self):
    str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0],
                                                        self.threadblock_shape[1],
                                                        self.threadblock_shape[2],
                                                        self.threadblock_output_shape[0],
                                                        self.threadblock_output_shape[1],
                                                        self.threadblock_output_shape[2],
                                                        self.threadblock_output_shape[3],
                                                        self.stages,
                                                        self.filter_shape[0],
                                                        self.filter_shape[1])
    # Fixed stride and dilation
    if self.stride != [-1, -1] and self.dilation != [-1, -1]:
      str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0],
                                                  self.stride[1],
                                                  self.dilation[0],
                                                  self.dilation[1])
    return str_name
#
class TensorDescription:
  def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
    self.element = element
    self.layout = layout
    self.alignment = alignment
    self.complex_transform = complex_transform
#
class SymmetricTensorDescription:
  def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left):
    self.element = element
    self.layout = layout
    self.fill_mode = fill_mode
    self.alignment = alignment
    self.complex_transform = complex_transform
    self.side_mode = side_mode
#
class TriangularTensorDescription:
  def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none):
    self.element = element
    self.layout = layout
    self.side_mode = side_mode
    self.fill_mode = fill_mode
    self.diag_type = diag_type
    self.alignment = alignment
    self.complex_transform = complex_transform
#
def CalculateSmemUsage(operation):
  cta_shape = operation.tile_description.threadblock_shape
  stages = operation.tile_description.stages

  if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse:
    # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity)
    if DataTypeSize[operation.A.element] == 32:
      elements_per_8b_md = 2
    elif DataTypeSize[operation.A.element] == 4:
      elements_per_8b_md = 8
    else:
      elements_per_8b_md = 4

    smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \
                     DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \
                     cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md
  else:
    # A few BLAS3 operations only have the A tensor, so B defaults to A's element size
    # and is overridden below for mixed-input kernels.
    data_type_size_a = DataTypeSize[operation.A.element]
    data_type_size_b = DataTypeSize[operation.A.element]
    if operation.is_mixed_input():
      data_type_size_b = DataTypeSize[operation.B.element]

    smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \
                     data_type_size_b * cta_shape[1] * cta_shape[2] // 8

  smem_usage = smem_per_stage * stages
  return (smem_usage >> 10)
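# Illustrative example (hand-worked, not exercised in this module): for a dense f16 x f16
# kernel with a 128x128x64 threadblock and 3 stages, the per-stage A and B tiles are
# 16 * 128 * 64 // 8 = 16384 bytes each, so CalculateSmemUsage() reports
# (2 * 16384 * 3) >> 10 = 96 KiB, a figure of the same kind tracked per architecture in
# SharedMemPerCC above.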
class GemmUniversalMode(enum.IntEnum):
  """
  Types corresponding to GemmUniversalMode
  """
  Gemm = 0
  GemmSplitKParallel = 1
  Batched = 2
  Array = 3
class SplitKMode(enum.IntEnum):
  """
  Types corresponding to SplitKMode
  """
  NoneSplitK = 0
  Serial = 1
  Parallel = 2