// cutlass/examples/39_gemm_permute/gemm_permute.cu
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief GEMM Permute Example.

This example computes batched GEMM operations with the output results permuted as reshaped tensors.

We provide a layout plugin as a flexible tool for users to add any customized input/output tensor
permute operation, or any other generalized global memory write-out address computation. To add a
customized layout, add a new class in include/cutlass/layout/permute.h.

In this example, we use the Tensor4DPermuteBMM0213 layout to perform a batched GEMM with
permute([0, 2, 1, 3]) applied to the whole BMM output tensor, and the Tensor5DPermute20314 layout
to perform a normal GEMM with permute([2, 0, 3, 1, 4]) applied to the output matrix. The address
computations are performed in compute(col_init, row_init, stride_init, BMM_batch_idx), with
{col_permute, row_permute, stride_permute} as the new addresses after the permute op
(see include/cutlass/layout/permute.h).
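For example, with D1 = 12 and a batch count of 96, the BMM output of shape [96, M, N] is viewed
as [8, 12, M, N], permuted to [8, M, 12, N], and handled as a batched tensor of shape [8, M, 12*N].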
Tips:
1) Make sure to set batch_stride to zero for BMM permute; also, the batched GEMM must run in mode
cutlass::gemm::GemmUniversalMode::kBatched instead of kArray.
2) When the contiguous dimension is touched by the permute op (for example, [0, 2, 3, 1] for a
row-major matrix or [1, 0, 2, 3] for a column-major one), Alignment must be set to 1 for the
corresponding matrix. If the last dimension is untouched, Alignment can be larger, e.g. 8 as in
this example. Consequently, permute ops that leave the unit-stride dimension untouched are
recommended for the best performance.
Examples:
# Runs a batched GEMM with 96 batches
$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96
# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)
$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96 --k=1024 --verbose=true
# Execute batched GEMM and profile with Nsight Compute
$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false
*/
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <map>
#include <unordered_map>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm_complex.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/layout/permute.h"
#include "layouts.h"
#include "permute_info.h"
/// Tensor4DPermuteBMM0213 --->
/// Permute layout for 4-D permuted tensors in BMM: the whole BMM tensor (of shape [B, M, N]) is
/// reshaped as [B/D1, D1, M, N], then permute([0, 2, 1, 3]) is applied to the reshaped tensor.
int constexpr D1 = 12;
/// Tensor5DPermute20314 --->
/// Permute layout for 5-D permuted tensors: a matrix (of shape [M, N]) is reshaped
/// as [M/T1, T1, T2, T3, N/T2/T3], then permute([2, 0, 3, 1, 4]) is applied to the reshaped tensor.
int constexpr T1 = 16;
int constexpr T2 = 3;
int constexpr T3 = 8;
/// Tensor4DPermute0213 --->
/// Permute layout for 4-D permuted tensors: a matrix (of shape [M, N]) is reshaped
/// as [M/S1, S1, S2, N/S2], then permute([0, 2, 1, 3]) is applied to the reshaped tensor.
int constexpr S1 = 8;
int constexpr S2 = 4;
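/// Illustrative only (not used by the kernels below): a hypothetical helper showing the extent
/// arithmetic behind Tensor4DPermute0213. A matrix of shape [M, N] is reshaped as
/// [M/S1, S1, S2, N/S2], permuted as [0, 2, 1, 3], and viewed again as a matrix of shape
/// [M*S2/S1, N*S1/S2].
inline cutlass::MatrixCoord permute0213_extent(cutlass::MatrixCoord const &extent) {
return cutlass::MatrixCoord(extent.row() / S1 * S2, extent.column() / S2 * S1);
}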
// Alignments
int constexpr AlignmentA = 8;
int constexpr AlignmentB = 8;
int constexpr AlignmentC = 8;
/// GEMM element types
using ElementInput = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Useful macros
#define CHECK_CUDA_CALL(call, handler) \
do { \
cudaError_t __err = (call); \
if (__err != cudaSuccess) { \
std::cerr << #call " failed: " << cudaGetErrorString(__err) << std::endl; \
handler; \
} \
} while(0)
#define CHECK_CUTLASS_CALL(call, handler) \
do { \
cutlass::Status __status = (call); \
if (__status != cutlass::Status::kSuccess) { \
std::cerr << #call " failed: " << cutlass::cutlassGetStatusString(__status) << std::endl; \
handler; \
} \
} while(0)
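// Example usage (mirrors the calls in Testbed below):
//   CHECK_CUDA_CALL(cudaEventRecord(events[0]), return false);
//   CHECK_CUTLASS_CALL(gemm.initialize(arguments, nullptr), return false);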
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
bool error;
bool reference_check;
cutlass::gemm::GemmCoord problem_each;
int batch_count;
int iterations;
int cuda_streams;
bool verbose;
float alpha;
float beta;
//
// Methods
//
Options():
help(false),
error(false),
reference_check(true),
batch_count(-1),
iterations(20),
cuda_streams(0),
verbose(false),
alpha(1),
beta()
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("alpha", alpha, 1.0f);
cmd.get_cmd_line_argument("beta", beta, 0.0f);
cmd.get_cmd_line_argument("iterations", iterations, 20);
cmd.get_cmd_line_argument("streams", cuda_streams, 0);
cmd.get_cmd_line_argument("verbose", verbose, false);
cmd.get_cmd_line_argument("reference-check", reference_check, true);
int m, n, k;
cmd.get_cmd_line_argument("m", m, 384);
cmd.get_cmd_line_argument("n", n, 192);
cmd.get_cmd_line_argument("k", k, 384);
cmd.get_cmd_line_argument("batch-count", batch_count, 96);
problem_each = cutlass::gemm::GemmCoord(m, n, k);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out <<
"39_gemm_permute\n"
"\n"
" This example tests and profiles the performance of normal GEMM and batched GEMM with different"
" combinations of fused permutations of input and output tensors."
"\n"
" Permutations considered in this example:\n"
"\n"
" Normal GEMM:\n"
" 1) Tensor4DPermute0213: matrix of shape [X, Y] is reshaped as [X/S1, S1, S2, Y/S2] and has its dimensions"
" permuted as [0, 2, 1, 3], resulting in shape [X/S1, S2, S1, Y/S2] viewed as matrix of shape [X*S2/S1, Y*S1/S2].\n"
" 2) Tensor5DPermute20314: matrix of shape [X, Y] is reshaped as [X/T1, T1, T2, T3, Y/T2/T3] and has its dimensions"
" permuted as [2, 0, 3, 1, 4], resulting in shape [T2, X/T1, T3, T1, Y/T2/T3] viewed as matrix of shape [X*T2/T1, Y*T1/T2].\n"
"\n"
" Batched GEMM:\n"
" 3) Tensor4DPermuteBMM0213: batched tensor of 3D shape [B, X, Y] is reshaped as 4D shape [B/D1, D1, X, Y]"
" and has its dimensions permuted as [0, 2, 1, 3], resulting in shape [B/D1, X, D1, Y] viewed as"
" a matrix of shape [B/D1, X, Y*D1] for batched GEMM purposes.\n"
"\n"
" Note: S1, S2, D1, D2, T1, T2, T3 are compile-time constants defined in gemm_permute.cu."
" Runtime specification of these values is not supported."
" These values along with alignment requirements place constraints on supported matrix sizes.\n"
"\n"
" Note: X, Y above may refer to M, N or K dimensions of GEMM problem, depending on the tensor considered (A, B or D)."
" For the output tensor D the values correspond directly to dimensions of D, whereas for A and B the original dimensions"
" X', Y' are inferred from the ones supplied to the GEMM, taking into account the permute operation.\n"
"\n"
"Options:\n"
"\n"
" --help If specified, displays this usage statement.\n\n"
" --batch-count=<int> Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n"
" --m=<int> Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n"
" --n=<int> Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n"
" --k=<int> Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=384)\n"
" --alpha=<f32> Epilogue scalar alpha (real part)\n"
" --beta=<f32> Epilogue scalar beta (real part)\n\n"
" --iterations=<int> Number of profiling iterations to perform.\n"
" --reference-check=<bool> If true, performs reference check.\n"
" --verbose=<bool> If true, prints problem sizes and batching structure.\n"
"\n"
"Examples:\n"
"\n"
"# Runs a batched GEMM with 96 batches\n"
"$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96\n"
"\n"
"# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n"
"$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96 --k=1024 --verbose=true\n"
"\n"
"# Execute batched GEMM and profile with NSight\n"
"$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n"
"\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s, bool batched) const {
// Number of real-valued multiply-adds
int64_t fmas = int64_t();
fmas += problem_each.product() * (batched ? batch_count : 1);
// Two flops per multiply-add
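// e.g., for the default problem (m=384, n=192, k=384) with 96 batches:
// fmas = 384 * 192 * 384 * 96 = 2,717,908,992, i.e. about 5.44 GFLOP per run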
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace { // (anonymous)
/// Dimension-generic permutation loop
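/// For a rank-3 coordinate, this compile-time recursion expands into three nested loops
/// over the input extents, with the innermost body assigning
/// output.at(permute(coord)) = input.at(coord).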
template<int I, typename Element, typename Layout, typename PermuteOp, typename Coord>
void permute_host_impl(
cutlass::TensorView<Element const, Layout> const & input,
cutlass::TensorView<Element, Layout> const & output,
PermuteOp && permute,
Coord & coord
) {
static_assert(Layout::kRank == Coord::kRank, "Incompatible Layout and Coord types");
if constexpr (I == Coord::kRank) {
output.at(permute(coord)) = input.at(coord);
}
else {
for (coord[I] = 0; coord[I] < input.extent(I); ++coord[I]) {
permute_host_impl<I+1>(input, output, std::forward<PermuteOp>(permute), coord);
}
}
}
} // namespace (anonymous)
/// Perform a reference (host-based) permutation of an input tensor
template<typename PermuteLayout, typename Element, typename Layout>
void permute_host(
cutlass::TensorView<Element const, Layout> const &input,
cutlass::TensorView<Element, Layout> const &output,
int batch_count) {
Layout layout = input.layout();
cutlass::MatrixCoord extent = input.extent();
std::size_t num_elems = layout.capacity(extent) * batch_count;
std::vector<Element> h_input(num_elems);
cutlass::device_memory::copy_to_host(h_input.data(), input.data(), num_elems);
std::vector<Element> h_output(num_elems);
using Info = PermuteInfo<PermuteLayout>;
using TensorLayout = typename Info::Layout;
auto shape_orig = Info::original_shape(extent, batch_count);
auto shape_perm = Info::permute(shape_orig);
cutlass::TensorView<Element const, TensorLayout> view_input(h_input.data(), TensorLayout::packed(shape_orig), shape_orig);
cutlass::TensorView<Element, TensorLayout> view_output(h_output.data(), TensorLayout::packed(shape_perm), shape_perm);
decltype(shape_orig) coord;
permute_host_impl<0>(view_input, view_output, Info::permute, coord);
cutlass::device_memory::copy_to_device(output.data(), h_output.data(), num_elems);
}
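// Example usage (illustrative; mirrors the reference check in Testbed::validate below):
//   permute_host<typename Gemm::PermuteDLayout>(view_D_ref.const_view(), view_D_perm, batch_count);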
template<typename Layout>
struct LayoutInfo;
template<>
struct LayoutInfo<cutlass::layout::RowMajor> {
static std::string name() { return "RowMajor"; }
};
template<>
struct LayoutInfo<cutlass::layout::ColumnMajor> {
static std::string name() { return "ColumnMajor"; }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename ElementA, typename ElementB, typename ElementC>
class Testbed {
private:
//
// Data members
//
Options & options;
/// Initialization
cutlass::Distribution::Kind init_A;
cutlass::Distribution::Kind init_B;
cutlass::Distribution::Kind init_C;
uint32_t seed;
cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementC> block_D;
public:
//
// Methods
//
Testbed(
Options &options_,
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
uint32_t seed_ = 3090
):
options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
private:
/// Print permutation info for one tensor
template<typename PermuteLayout>
void print_tensor_info(
std::ostream & os,
std::string const &tensor_name,
int row_dim,
int col_dim) {
cutlass::MatrixCoord extent(options.problem_each.at(row_dim), options.problem_each.at(col_dim));
using Info = PermuteInfo<PermuteLayout>;
os << "tensor " << tensor_name << ": " << Info::desc() << "\n";
os << " extent: [" << extent.row() << ", " << extent.column() << "]";
if (Info::kBatched) {
os << ", batch count: " << options.batch_count;
}
os << "\n";
if (!cutlass::layout::is_trivial_permute<PermuteLayout>) {
auto shape_orig = Info::original_shape(extent, options.batch_count);
auto shape_perm = Info::permute(shape_orig);
os << " original: [" << shape_orig << "]\n";
os << " permuted: [" << shape_perm << "]\n";
}
}
/// Check shape compatibility for one tensor
template<typename Layout, typename PermuteLayout, int Alignment>
bool check_tensor_shape(
std::string const &tensor_name,
int row_dim,
int col_dim) {
cutlass::MatrixCoord extent(options.problem_each.at(row_dim), options.problem_each.at(col_dim));
using Info = PermuteInfo<PermuteLayout>;
auto rowAlign = cutlass::platform::is_same<Layout, cutlass::layout::ColumnMajor>::value ? Alignment : 1;
auto colAlign = cutlass::platform::is_same<Layout, cutlass::layout::RowMajor>::value ? Alignment : 1;
auto rowFactor = Info::kRowFactor * rowAlign;
auto colFactor = Info::kColumnFactor * colAlign;
// Alignment applies to the contiguous dimension: rows for column-major, columns for row-major
bool const valid_row = extent.row() % rowFactor == 0;
if (!valid_row) {
std::cerr << "Invalid tensor " << tensor_name << " row size = " << extent.row() << ", "
"must be divisible by " << rowFactor << ", "
"required by " << Info::name() <<
(rowAlign > 1 ? (" and alignment of " + std::to_string(rowAlign)) : "") << std::endl;
}
bool const valid_col = extent.column() % colFactor == 0;
if (!valid_col) {
std::cerr << "Invalid tensor " << tensor_name << " column size = " << extent.column() << ", "
"must be divisible by " << colFactor << ", "
"required by " << Info::name() <<
(colAlign > 1 ? (" and alignment of " + std::to_string(colAlign)) : "") << std::endl;
}
bool const valid_bsz = options.batch_count % Info::kBatchFactor == 0;
if (!valid_bsz) {
std::cerr << "Invalid batch count = " << options.batch_count << ", "
"must be divisible by " << Info::kBatchFactor << ", "
"required by " << Info::name() << std::endl;
}
return valid_row && valid_col && valid_bsz;
}
/// Helper to initialize a tensor view
template <typename Element>
void initialize_tensor_(
Element *ptr,
size_t capacity,
cutlass::Distribution::Kind dist_kind,
uint32_t seed) {
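// Choose the random value range based on the element widths: narrower inputs and
// accumulators get smaller ranges so that results stay exactly comparable against
// the reference (see BlockCompareEqual in validate()).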
if (dist_kind == cutlass::Distribution::Uniform) {
Element scope_max, scope_min;
int bits_input = cutlass::sizeof_bits<Element>::value;
int bits_output = cutlass::sizeof_bits<ElementC>::value;
if (bits_input == 1) {
scope_max = 2;
scope_min = 0;
} else if (bits_input <= 8) {
scope_max = 2;
scope_min = -2;
} else if (bits_output == 16) {
if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
scope_max = 5;
scope_min = -5;
}
else {
scope_max = 8;
scope_min = -8;
}
} else {
scope_max = 8;
scope_min = -8;
}
cutlass::reference::device::BlockFillRandomUniform(
ptr, capacity, seed, scope_max, scope_min, 0);
}
else if (dist_kind == cutlass::Distribution::Gaussian) {
cutlass::reference::device::BlockFillRandomGaussian(
ptr, capacity, seed, Element(), Element(0.5f));
}
else if (dist_kind == cutlass::Distribution::Sequential) {
// Fill with increasing elements
cutlass::reference::device::BlockFillSequential(
ptr, capacity, Element(1), Element());
}
else {
// Fill with all 1s
cutlass::reference::device::BlockFillSequential(
ptr, capacity, Element(), Element(1));
}
}
/// Initializes data structures
void initialize(int batch_count) {
srand(seed);
int64_t total_elements_A = options.problem_each.m() * options.problem_each.k() * batch_count;
int64_t total_elements_B = options.problem_each.n() * options.problem_each.k() * batch_count;
int64_t total_elements_C = options.problem_each.m() * options.problem_each.n() * batch_count;
int64_t total_elements_D = options.problem_each.m() * options.problem_each.n() * batch_count;
// Allocate space
block_A.reset(total_elements_A);
block_B.reset(total_elements_B);
block_C.reset(total_elements_C);
block_D.reset(total_elements_D);
// Initialize input tensors
initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021);
initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022);
initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023);
cutlass::reference::device::BlockFillSequential(
block_D.get(), total_elements_D, ElementC(), ElementC());
}
/// Check device GEMM results against a reference implementation with separate host-based permutation
template<typename Gemm>
bool validate(Gemm const &gemm) {
bool constexpr kBatched = PermuteInfo<typename Gemm::PermuteALayout>::kBatched
|| PermuteInfo<typename Gemm::PermuteBLayout>::kBatched
|| PermuteInfo<typename Gemm::PermuteDLayout>::kBatched;
int const batch_count = kBatched ? options.batch_count : 1;
cutlass::gemm::GemmCoord problem = options.problem_each;
cutlass::MatrixCoord extent_A{problem.m(), problem.k()};
cutlass::MatrixCoord extent_B{problem.k(), problem.n()};
cutlass::MatrixCoord extent_C{problem.m(), problem.n()};
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
LayoutA layout_A(LayoutA::packed(extent_A));
LayoutB layout_B(LayoutB::packed(extent_B));
LayoutC layout_C(LayoutC::packed(extent_C));
auto size_A = layout_A.capacity(extent_A) * batch_count;
auto size_B = layout_B.capacity(extent_B) * batch_count;
auto size_C = layout_C.capacity(extent_C) * batch_count;
cutlass::TensorView<ElementA, LayoutA> view_A(block_A.get(), layout_A, extent_A);
cutlass::TensorView<ElementB, LayoutB> view_B(block_B.get(), layout_B, extent_B);
cutlass::TensorView<ElementC, LayoutC> view_C(block_C.get(), layout_C, extent_C);
cutlass::TensorView<ElementC, LayoutC> view_D(block_D.get(), layout_C, extent_C);
cutlass::DeviceAllocation<ElementA> block_A_perm(size_A);
cutlass::DeviceAllocation<ElementB> block_B_perm(size_B);
cutlass::TensorView<ElementA, LayoutA> view_A_perm(block_A_perm.get(), layout_A, extent_A);
cutlass::TensorView<ElementB, LayoutB> view_B_perm(block_B_perm.get(), layout_B, extent_B);
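// The device kernel applies the permutations while reading A and B; materialize explicitly
// permuted host copies so the plain reference GEMM below consumes the same effective operands.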
permute_host<typename Gemm::PermuteALayout>(view_A.const_view(), view_A_perm, batch_count);
permute_host<typename Gemm::PermuteBLayout>(view_B.const_view(), view_B_perm, batch_count);
cutlass::DeviceAllocation<ElementC> block_D_ref(size_C);
cutlass::TensorView<ElementC, LayoutC> view_D_ref(block_D_ref.get(), layout_C, extent_C);
using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
// Reference GEMM
cutlass::reference::device::GemmComplex<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
typename EpilogueOutputOp::ElementCompute,
typename Gemm::ElementAccumulator
>(
problem,
options.alpha,
view_A_perm,
Gemm::kTransformA,
view_B_perm,
Gemm::kTransformB,
options.beta,
view_C,
view_D_ref,
ElementAccumulator(0),
batch_count,
options.problem_each.m() * options.problem_each.k(),
options.problem_each.n() * options.problem_each.k(),
options.problem_each.m() * options.problem_each.n(),
options.problem_each.m() * options.problem_each.n()
);
cutlass::DeviceAllocation<ElementC> block_D_perm(size_C);
cutlass::TensorView<ElementC, LayoutC> view_D_perm(block_D_perm.get(), layout_C, extent_C);
permute_host<typename Gemm::PermuteDLayout>(view_D_ref.const_view(), view_D_perm, batch_count);
// Reference check
return cutlass::reference::device::BlockCompareEqual(view_D_perm.data(), view_D.data(), size_C);
}
public:
template<typename Gemm>
bool profile_GEMM_permute() {
using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using PermuteALayout = typename Gemm::PermuteALayout;
using PermuteBLayout = typename Gemm::PermuteBLayout;
using PermuteDLayout = typename Gemm::PermuteDLayout;
bool constexpr kBatched = PermuteInfo<PermuteALayout>::kBatched
|| PermuteInfo<PermuteBLayout>::kBatched
|| PermuteInfo<PermuteDLayout>::kBatched;
std::cout << "\n"
"====================================================\n"
<< (kBatched ? "Batched" : "Normal") << " GEMM:"
<< "\n A=" << LayoutInfo<LayoutA>::name() << "," << PermuteInfo<PermuteALayout>::name()
<< "\n B=" << LayoutInfo<LayoutB>::name() << "," << PermuteInfo<PermuteBLayout>::name()
<< "\n D=" << LayoutInfo<LayoutC>::name() << "," << PermuteInfo<PermuteDLayout>::name()
<< "\n"
"====================================================\n";
if (options.verbose) {
print_tensor_info<PermuteALayout>(std::cout, "A", 0, 2);
print_tensor_info<PermuteBLayout>(std::cout, "B", 2, 1);
print_tensor_info<PermuteDLayout>(std::cout, "D", 0, 1);
}
std::cout << std::endl;
bool valid = true;
valid &= check_tensor_shape<LayoutA, PermuteALayout, Gemm::kAlignmentA>("A", 0, 2);
valid &= check_tensor_shape<LayoutB, PermuteBLayout, Gemm::kAlignmentB>("B", 2, 1);
valid &= check_tensor_shape<LayoutC, PermuteDLayout, Gemm::kAlignmentC>("D", 0, 1);
if (!valid)
{
std::cout << "Skipped test" << std::endl;
return true;
}
int const batch_count = kBatched ? options.batch_count : 1;
// Initialize the problem
initialize(batch_count);
// Configure the GEMM arguments
using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp;
typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta);
// Note: in kBatched mode, all problems in the batch share the same problem size
auto problem = options.problem_each;
cutlass::MatrixCoord extent_A{problem.m(), problem.k()};
cutlass::MatrixCoord extent_B{problem.k(), problem.n()};
cutlass::MatrixCoord extent_C{problem.m(), problem.n()};
LayoutA layout_A(LayoutA::packed(extent_A));
LayoutB layout_B(LayoutB::packed(extent_B));
LayoutC layout_C(LayoutC::packed(extent_C));
// Configure GEMM arguments
typename Gemm::Arguments arguments{
kBatched ? cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm,
problem,
batch_count,
epilogue_op,
(void*)block_A.get(),
(void*)block_B.get(),
(void*)block_C.get(),
(void*)block_D.get(),
// For any non-trivial permute the batch stride must be set to 0
cutlass::layout::is_trivial_permute<PermuteALayout> ? layout_A.capacity(extent_A) : 0,
cutlass::layout::is_trivial_permute<PermuteBLayout> ? layout_B.capacity(extent_B) : 0,
layout_C.capacity(extent_C),
cutlass::layout::is_trivial_permute<PermuteDLayout> ? layout_C.capacity(extent_C) : 0,
layout_A.stride(0),
layout_B.stride(0),
layout_C.stride(0),
layout_C.stride(0),
};
// Initialize the GEMM object
Gemm gemm_normal;
CHECK_CUTLASS_CALL(gemm_normal.initialize(arguments, nullptr), return false);
// Run the normal GEMM object
CHECK_CUTLASS_CALL(gemm_normal.run(), return false);
// Wait for completion
CHECK_CUDA_CALL(cudaDeviceSynchronize(), return false);
//
// Verify correctness
//
if (options.reference_check) {
if (validate(gemm_normal)) {
std::cout << "\nPassed verification\n" << std::endl;
}
else {
std::cerr << "\n*** Error - problem failed the QA check ***\n" << std::endl;
return false;
}
}
// Warm-up run of the normal GEMM object
CHECK_CUTLASS_CALL(gemm_normal.run(), return false);
// Construct events
cudaEvent_t events[2];
for (auto & event : events) {
CHECK_CUDA_CALL(cudaEventCreate(&event), return false);
}
// Record an event at the start of a series of GEMM operations
CHECK_CUDA_CALL(cudaEventRecord(events[0]), return false);
// Run profiling loop
for (int iter = 0; iter < options.iterations; ++iter) {
gemm_normal();
}
// Record an event when the GEMM operations have been launched.
CHECK_CUDA_CALL(cudaEventRecord(events[1]), return false);
// Wait for work on the device to complete.
CHECK_CUDA_CALL(cudaEventSynchronize(events[1]), return false);
// Measure elapsed runtime
float runtime_total_ms = 0;
CHECK_CUDA_CALL(cudaEventElapsedTime(&runtime_total_ms, events[0], events[1]), return false);
// Compute average runtime and GFLOPs.
double runtime_avg_ms = double(runtime_total_ms) / double(options.iterations);
double gflops = options.gflops(runtime_avg_ms / 1000.0, kBatched);
// Cleanup
for (auto event : events) {
CHECK_CUDA_CALL(cudaEventDestroy(event), return false);
}
std::cout << " Runtime: " << runtime_avg_ms << " ms\n"
" GFLOPs: " << gflops << std::endl;
return true;
}
};
/// Shorthand alias for GEMM instantiations
template<typename LayoutA, typename PermuteALayout,
typename LayoutB, typename PermuteBLayout,
typename LayoutC, typename PermuteDLayout>
using GemmPermute = cutlass::gemm::device::GemmUniversal<
ElementInput, LayoutA,
ElementInput, LayoutB,
ElementOutput, LayoutC,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
AlignmentC, //128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementAccumulator
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
4, /*kStages*/
AlignmentA, /*AlignmentA*/
AlignmentB, /*AlignmentB*/
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone,
false, /*GatherA*/
false, /*GatherB*/
false, /*ScatterD*/
PermuteDLayout, /*PermuteDLayout*/
typename cutlass::layout::InversePermute<PermuteALayout>::type, /*PermuteALayout*/
typename cutlass::layout::InversePermute<PermuteBLayout>::type /*PermuteBLayout*/
>;
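// Note: the permutations requested for A and B are wrapped in InversePermute. Since A and B
// are inputs, the kernel maps each logical (row, column) it reads back to its source location
// in memory, which uses the inverse of the user-visible permute operation.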
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
//
// This example uses mma.sync to directly access Tensor Cores to achieve peak performance.
//
cudaDeviceProp props;
CHECK_CUDA_CALL(cudaGetDeviceProperties(&props, 0), return EXIT_FAILURE);
if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) {
//
// This example requires an NVIDIA Ampere-architecture GPU.
//
std::cout << "CUTLASS's GEMM+Permute example requires a GPU of NVIDIA's Ampere Architecture "
"or later (compute capability 80 or greater).\n";
return EXIT_SUCCESS;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return EXIT_SUCCESS;
}
if (options.error) {
std::cerr << "Aborting execution." << std::endl;
return EXIT_FAILURE;
}
//
// Define GEMM types to test
//
//
// TTT (Row-major) GEMMs
//
using TTTGemmNormalPermuteNone = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::NoPermute
>;
using TTTGemmNormalPermuteA = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::NoPermute
>;
using TTTGemmNormalPermuteAD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
>;
using TTTGemmNormalPermuteB = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute
>;
using TTTGemmNormalPermuteBD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
>;
using TTTGemmNormalPermuteD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
>;
using TTTGemmNormalPermuteAB = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute
>;
using TTTGemmNormalPermuteABD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
>;
//
// NNN (Col-major) GEMMs
//
using NNNGemmNormalPermuteNone = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using NNNGemmNormalPermuteA = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using NNNGemmNormalPermuteAD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
>;
using NNNGemmNormalPermuteB = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using NNNGemmNormalPermuteBD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
>;
using NNNGemmNormalPermuteD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
>;
using NNNGemmNormalPermuteAB = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using NNNGemmNormalPermuteABD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
>;
//
// NNT (Col-major inputs, row-major output) GEMMs
//
using NNTGemmNormalPermuteNone = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::NoPermute
>;
using NNTGemmNormalPermuteA = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::NoPermute
>;
using NNTGemmNormalPermuteAD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
>;
using NNTGemmNormalPermuteB = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute
>;
using NNTGemmNormalPermuteBD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
>;
using NNTGemmNormalPermuteD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
>;
using NNTGemmNormalPermuteAB = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute
>;
using NNTGemmNormalPermuteABD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor<T1, T2, T3>
>;
//
// TTN (Row-major inputs, col-major output) GEMMs
//
using TTNGemmNormalPermuteNone = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using TTNGemmNormalPermuteA = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using TTNGemmNormalPermuteAD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
>;
using TTNGemmNormalPermuteB = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using TTNGemmNormalPermuteBD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
>;
using TTNGemmNormalPermuteD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
>;
using TTNGemmNormalPermuteAB = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using TTNGemmNormalPermuteABD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor<S1, S2>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor<T1, T2, T3>
>;
//
// TTT (Row-major) BMMs
//
using TTTGemmBatchedPermuteA = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::NoPermute
>;
using TTTGemmBatchedPermuteAD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>
>;
using TTTGemmBatchedPermuteB = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
cutlass::layout::RowMajor, cutlass::layout::NoPermute
>;
using TTTGemmBatchedPermuteBD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>
>;
using TTTGemmBatchedPermuteD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>
>;
using TTTGemmBatchedPermuteAB = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::NoPermute,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>
>;
using TTTGemmBatchedPermuteABD = GemmPermute<
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>,
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor<D1>
>;
//
// NNN (Col-major) BMMs
//
using NNNGemmBatchedPermuteA = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using NNNGemmBatchedPermuteAD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>
>;
using NNNGemmBatchedPermuteB = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using NNNGemmBatchedPermuteBD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>
>;
using NNNGemmBatchedPermuteD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>
>;
using NNNGemmBatchedPermuteAB = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute
>;
using NNNGemmBatchedPermuteABD = GemmPermute<
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>,
cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor<D1>
>;
//
// Profile it
//
Testbed<ElementInput, ElementInput, ElementOutput> testbed(options);
bool result = true;
result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteNone>();
result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteA>();
result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteAD>();
result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteB>();
result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteBD>();
result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteD>();
result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteAB>();
result &= testbed.profile_GEMM_permute<TTTGemmNormalPermuteABD>();
result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteNone>();
result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteA>();
result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteAD>();
result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteB>();
result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteBD>();
result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteD>();
result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteAB>();
result &= testbed.profile_GEMM_permute<NNNGemmNormalPermuteABD>();
result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteNone>();
result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteA>();
result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteAD>();
result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteB>();
result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteBD>();
result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteD>();
result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteAB>();
result &= testbed.profile_GEMM_permute<NNTGemmNormalPermuteABD>();
result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteNone>();
result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteA>();
result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteAD>();
result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteB>();
result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteBD>();
result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteD>();
result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteAB>();
result &= testbed.profile_GEMM_permute<TTNGemmNormalPermuteABD>();
result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteA>();
result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteAD>();
result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteB>();
result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteBD>();
result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteD>();
result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteAB>();
result &= testbed.profile_GEMM_permute<TTTGemmBatchedPermuteABD>();
result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteA>();
result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteAD>();
result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteB>();
result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteBD>();
result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteD>();
result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteAB>();
result &= testbed.profile_GEMM_permute<NNNGemmBatchedPermuteABD>();
std::cout << "\n"
"====================================================\n"
"Finished (" << (result ? "PASS" : "FAIL") << ")\n"
"====================================================" << std::endl;
return result ? EXIT_SUCCESS : EXIT_FAILURE;
}
/////////////////////////////////////////////////////////////////////////////////////////////////