cutlass/examples/21_quaternion_gemm/quaternion_gemm.cu

455 lines
17 KiB
Plaintext

/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result {
double runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
//
// Methods
//
Result(
double runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess
):
runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Command-line options for the quaternion GEMM example.
struct Options {

  bool help;                                ///< true when --help was given
  cutlass::gemm::GemmCoord problem_size;    ///< GEMM extents, set via --m / --n / --k
  int batch_count;                          ///< value of --batch; only consumed by gflops() in this struct
  cutlass::Quaternion<float> alpha;         ///< epilogue scalar alpha (w = real part; x, y, z = i, j, k parts)
  cutlass::Quaternion<float> beta;          ///< epilogue scalar beta, default-constructed
  bool reference_check;                     ///< initialized to true; parse() never modifies it
  int iterations;                           ///< number of profiling iterations

  /// Sets the defaults: a 1024 x 1024 x 1024 problem, one batch, 20 iterations.
  /// The initializers are listed in member declaration order, which is the order the
  /// language actually runs them in (avoids -Wreorder warnings).
  Options():
    help(false),
    problem_size({1024, 1024, 1024}),
    batch_count(1),
    alpha(1),
    beta(),
    reference_check(true),
    iterations(20) { }

  /// Returns true when every extent and count is strictly positive.
  /// Zero or negative values would describe an empty or meaningless problem and would
  /// make the average-runtime computation divide by zero.
  bool valid() const {
    return problem_size.m() > 0 &&
           problem_size.n() > 0 &&
           problem_size.k() > 0 &&
           batch_count > 0 &&
           iterations > 0;
  }

  /// Parses the command line, overwriting only the options that are present.
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    cmd.get_cmd_line_argument("m", problem_size.m());
    cmd.get_cmd_line_argument("n", problem_size.n());
    cmd.get_cmd_line_argument("k", problem_size.k());
    cmd.get_cmd_line_argument("batch", batch_count);

    // Each quaternion component has its own flag; the unsuffixed flag is the real part.
    cmd.get_cmd_line_argument("alpha", alpha.w());
    cmd.get_cmd_line_argument("alpha_i", alpha.x());
    cmd.get_cmd_line_argument("alpha_j", alpha.y());
    cmd.get_cmd_line_argument("alpha_k", alpha.z());

    cmd.get_cmd_line_argument("beta", beta.w());
    cmd.get_cmd_line_argument("beta_i", beta.x());
    cmd.get_cmd_line_argument("beta_j", beta.y());
    cmd.get_cmd_line_argument("beta_k", beta.z());

    cmd.get_cmd_line_argument("iterations", iterations);
  }

  /// Prints the usage statement to `out` and returns `out` so calls can be chained.
  std::ostream & print_usage(std::ostream &out) const {
    out << "21_quaternion_gemm example\n\n"
      << " This example uses the CUTLASS Library to execute Quaternion GEMM computations.\n\n"
      << "Options:\n\n"
      << " --help If specified, displays this usage statement.\n\n"
      << " --m=<int> GEMM M dimension\n"
      << " --n=<int> GEMM N dimension\n"
      << " --k=<int> GEMM K dimension\n"
      << " --batch=<int> Number of GEMM operations executed in one batch\n"
      << " --alpha=<f32> Epilogue scalar alpha (real part)\n"
      << " --alpha_i=<f32> Epilogue scalar alpha_i (imaginary part)\n"
      << " --alpha_j=<f32> Epilogue scalar alpha_j (imaginary part)\n"
      << " --alpha_k=<f32> Epilogue scalar alpha_k (imaginary part)\n"
      << " --beta=<f32> Epilogue scalar beta (real part)\n\n"
      << " --beta_i=<f32> Epilogue scalar beta_i (imaginary part)\n\n"
      << " --beta_j=<f32> Epilogue scalar beta_j (imaginary part)\n\n"
      << " --beta_k=<f32> Epilogue scalar beta_k (imaginary part)\n\n"
      << " --iterations=<int> Number of profiling iterations to perform.\n\n";

    out << "\n\nExamples:\n\n"
      << "$ ./examples/21_quaternion_gemm/21_quaternion_gemm --batch=7 --m=1024 --n=512 --k=1024 \\\n"
      << " --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n";

    return out;
  }

  /// Computes performance in GFLOP/s for one pass that took `runtime_s` seconds.
  double gflops(double runtime_s) const {
    // Number of real-valued multiply-adds: 16 per quaternion multiply-accumulate,
    // scaled by the problem volume and the batch count.
    // NOTE(review): run() in this file launches a single, non-batched GEMM, so a
    // batch_count other than 1 scales this figure without scaling the work — confirm intent.
    int64_t fmas = problem_size.product() * batch_count * 16;

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// The code section below describes datatype for input, output matrices and computation between
// elements in input matrices.
// Scalar precision alias. It is not referenced by the other aliases in this section, which
// spell out cutlass::Quaternion<float> directly.
using precision = float;
// Every operand, the accumulator and the epilogue scalars share one quaternion element type.
using Element = cutlass::Quaternion<float>;
using ElementComputeEpilogue = Element; // <- data type of epilogue operations (alpha / beta)
using ElementAccumulator = Element; // <- data type of accumulator
using ElementInputA = Element; // <- data type of elements in input matrix A
using ElementInputB = Element; // <- data type of elements in input matrix B
using ElementOutput = Element; // <- data type of elements in output matrices C and D
// The code section below describes matrix layout of input and output matrices. Row Major for
// Matrix A, Column Major for Matrix B and Row Major for Matrices C and D
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM.
// This example selects the SIMT operator class.
using MMAOp = cutlass::arch::OpClassSimt;
// This code section describes CUDA SM architecture number
using SmArch = cutlass::arch::Sm50;
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
cutlass::gemm::GemmShape<64, 64, 4>; // <- threadblock tile M = 64, N = 64, K = 4
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 16, 4>; // <- warp tile M = 32, N = 16, K = 4
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<1, 1, 1>; // <- MMA Op tile M = 1, N = 1, K = 1
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- Defaults
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput, // <- data type of output matrix
128 / cutlass::sizeof_bits<ElementOutput>::value, // <- the number of elements per vectorized
// memory access, sized so that one access
// moves 128 bits. A quaternion of four floats
// is presumably 128 bits wide, which would make
// this 1 element per access -- confirm against
// cutlass::sizeof_bits for Quaternion<float>.
ElementAccumulator, // <- data type of accumulator
ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
// Number of pipeline stages you want to use
constexpr int NumStages = 2;
// The device-level GEMM type, assembled from the element types, layouts, tile shapes,
// epilogue, swizzle and stage count declared above.
using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages>;
/// Runs the quaternion GEMM described by `options`, times it over `options.iterations`
/// launches, and (when `options.reference_check` is set) compares the output against the
/// device reference GEMM.
///
/// Prints the average runtime and GFLOP/s on success, followed by "Passed" or "Failed".
///
/// \param options  parsed command-line options (problem size, alpha/beta, iteration count)
/// \return 0 when the run passed; -1 on a verification mismatch, a CUDA runtime error, or a
///         non-positive iteration count. CUTLASS status failures are handled by the
///         CUTLASS_CHECK macro from helper.h (behavior defined there, not in this file).
int run(Options options) {

  // With no iterations the kernel never runs and the average runtime would divide by zero.
  if (options.iterations <= 0) {
    std::cerr << "Invalid iteration count: " << options.iterations << std::endl;
    return -1;
  }

  // PASS/FAIL status
  bool passed = true;

  // Create a tuple of problem size for matrix multiplication
  cutlass::gemm::GemmCoord problem_size = options.problem_size;

  // Allocate the operands: A is M x K, B is K x N, C / D / reference D are M x N
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());      // CUTLASS output
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(problem_size.mn());  // reference output

  // Fill A, B and C on the host with uniform-distribution random data. All three use the same
  // arguments (1, 4, -4, 0) -- presumably seed, max, min and fractional bits; confirm against
  // tensor_fill.h.
  cutlass::reference::host::TensorFillRandomUniform(tensor_a.host_view(), 1, 4, -4, 0);
  cutlass::reference::host::TensorFillRandomUniform(tensor_b.host_view(), 1, 4, -4, 0);
  cutlass::reference::host::TensorFillRandomUniform(tensor_c.host_view(), 1, 4, -4, 0);

  // Fill both output matrices with the default fill value (zeros per the original comments)
  cutlass::reference::host::TensorFill(tensor_d.host_view());
  cutlass::reference::host::TensorFill(tensor_ref_d.host_view());

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();
  tensor_ref_d.sync_device();

  // Epilogue scalars come from the command line (defaults: alpha = 1, beta default-constructed),
  // so the --alpha* / --beta* flags advertised in the usage text take effect.
  ElementComputeEpilogue alpha = options.alpha;
  ElementComputeEpilogue beta = options.beta;

  // Split K dimension into 1 partition
  int split_k_slices = 1;

  // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
  // instantiated CUTLASS kernel
  typename Gemm::Arguments arguments{problem_size,           // <- problem size of matrix multiplication
                                     tensor_a.device_ref(),  // <- reference to matrix A on device
                                     tensor_b.device_ref(),  // <- reference to matrix B on device
                                     tensor_c.device_ref(),  // <- reference to matrix C on device
                                     tensor_d.device_ref(),  // <- reference to matrix D on device
                                     {alpha, beta},          // <- tuple of alpha and beta
                                     split_k_slices};        // <- k-dimension split factor

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm_op;

  // Check the problem size is supported or not
  cutlass::Status status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(status);

  // Result structure
  Result result;

  // Reports the CUDA runtime error held in result.error, if any; returns true on success.
  auto cuda_ok = [&result](char const *api) -> bool {
    if (result.error == cudaSuccess) {
      return true;
    }
    std::cerr << api << " failed: " << cudaGetErrorString(result.error) << std::endl;
    return false;
  };

  //
  // Construct events
  //

  // Owns the two timing events and destroys whichever were created when the guard goes out
  // of scope, so early returns below do not leak them.
  struct EventPair {
    cudaEvent_t handles[2] = {nullptr, nullptr};
    ~EventPair() {
      for (cudaEvent_t handle : handles) {
        if (handle != nullptr) {
          (void)cudaEventDestroy(handle);
        }
      }
    }
  } events;

  for (auto & event : events.handles) {
    result.error = cudaEventCreate(&event);
    if (!cuda_ok("cudaEventCreate()")) {
      return -1;
    }
  }

  // Record an event at the start of a series of GEMMs
  result.error = cudaEventRecord(events.handles[0]);
  if (!cuda_ok("cudaEventRecord()")) {
    return -1;
  }

  //
  // Run profiling loop
  //
  for (int iter = 0; iter < options.iterations; ++iter) {
    // Launch initialized CUTLASS kernel
    status = gemm_op();
    CUTLASS_CHECK(status);
  }

  //
  // Stop profiling loop
  //

  // Record an event when the GEMMs are complete
  result.error = cudaEventRecord(events.handles[1]);
  if (!cuda_ok("cudaEventRecord()")) {
    return -1;
  }

  // Wait for work on the device to complete.
  result.error = cudaEventSynchronize(events.handles[1]);
  if (!cuda_ok("cudaEventSynchronize()")) {
    return -1;
  }

  // Measure elapsed runtime
  float runtime_ms = 0;
  result.error = cudaEventElapsedTime(&runtime_ms, events.handles[0], events.handles[1]);
  if (!cuda_ok("cudaEventElapsedTime()")) {
    return -1;
  }

  // Compute average runtime and GFLOPs.
  result.status = status;
  result.runtime_ms = double(runtime_ms) / double(options.iterations);
  result.gflops = options.gflops(result.runtime_ms / 1000.0);

  if (options.reference_check) {

    // Create instantiation for device reference gemm kernel
    cutlass::reference::device::Gemm<ElementInputA,
                                     LayoutInputA,
                                     ElementInputB,
                                     LayoutInputB,
                                     ElementOutput,
                                     LayoutOutput,
                                     ElementComputeEpilogue,
                                     ElementComputeEpilogue> gemm_device;

    // Launch device reference gemm kernel
    gemm_device(problem_size,
                alpha,
                tensor_a.device_ref(),
                tensor_b.device_ref(),
                beta,
                tensor_c.device_ref(),
                tensor_ref_d.device_ref());

    // Wait for kernels to finish; a failure here means one of the kernels faulted, so the
    // comparison below would be meaningless.
    result.error = cudaDeviceSynchronize();
    if (!cuda_ok("cudaDeviceSynchronize()")) {
      return -1;
    }

    // Copy output data from CUTLASS and reference kernel to host for comparison
    tensor_d.sync_host();
    tensor_ref_d.sync_host();

    // Check if output from CUTLASS kernel and reference kernel are equal or not.
    // NOTE(review): this comparison is exact; alpha/beta values that are not exactly
    // representable may expose rounding differences between the two kernels -- confirm
    // whether a tolerance is needed for such inputs.
    passed &= cutlass::reference::host::TensorEquals(
      tensor_d.host_view(),
      tensor_ref_d.host_view());
  }

  result.passed = passed;

  if (passed) {
    std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
    std::cout << " GFLOPs: " << result.gflops << std::endl;
  }

  std::cout << (passed ? "Passed" : "Failed") << std::endl;

  return (passed ? 0 : -1);
}
int main(int argc, char const** argv) {
Options options;
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
printf("%d x %d x %d Single Precision Quaternion Matrix Multiply\n", \
options.problem_size.m(), options.problem_size.n(), options.problem_size.k());
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
return run(options);
}