/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/quaternion.h"  // for cutlass::Quaternion

#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "helper.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Result structure
struct Result {

  double runtime_ms;
  double gflops;
  cutlass::Status status;
  cudaError_t error;
  bool passed;

  //
  // Methods
  //

  Result(
    double runtime_ms = 0,
    double gflops = 0,
    cutlass::Status status = cutlass::Status::kSuccess,
    cudaError_t error = cudaSuccess
  ):
    runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

// Command line options parsing
struct Options {

  bool help;

  cutlass::gemm::GemmCoord problem_size;
  int batch_count;
  cutlass::Quaternion<float> alpha;
  cutlass::Quaternion<float> beta;

  bool reference_check;
  int iterations;

  // Initializers are listed in member declaration order to match C++ initialization order
  Options():
    help(false),
    problem_size({1024, 1024, 1024}),
    batch_count(1),
    alpha(1),
    beta(),
    reference_check(true),
    iterations(20) { }

  bool valid() {
    return true;
  }

  // Parses the command line
  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    if (cmd.check_cmd_line_flag("help")) {
      help = true;
    }

    cmd.get_cmd_line_argument("m", problem_size.m());
    cmd.get_cmd_line_argument("n", problem_size.n());
    cmd.get_cmd_line_argument("k", problem_size.k());
    cmd.get_cmd_line_argument("batch", batch_count);

    cmd.get_cmd_line_argument("alpha", alpha.w());
    cmd.get_cmd_line_argument("alpha_i", alpha.x());
    cmd.get_cmd_line_argument("alpha_j", alpha.y());
    cmd.get_cmd_line_argument("alpha_k", alpha.z());

    cmd.get_cmd_line_argument("beta", beta.w());
    cmd.get_cmd_line_argument("beta_i", beta.x());
    cmd.get_cmd_line_argument("beta_j", beta.y());
    cmd.get_cmd_line_argument("beta_k", beta.z());

    cmd.get_cmd_line_argument("iterations", iterations);
  }

  /// Prints the usage statement.
  std::ostream & print_usage(std::ostream &out) const {

    out << "21_quaternion_gemm example\n\n"
      << "  This example uses the CUTLASS Library to execute Quaternion GEMM computations.\n\n"
      << "Options:\n\n"
      << "  --help                      If specified, displays this usage statement.\n\n"
      << "  --m=<int>                   GEMM M dimension\n"
      << "  --n=<int>                   GEMM N dimension\n"
      << "  --k=<int>                   GEMM K dimension\n"
      << "  --batch=<int>               Number of GEMM operations executed in one batch\n"
      << "  --alpha=<f32>               Epilogue scalar alpha (real part)\n"
      << "  --alpha_i=<f32>             Epilogue scalar alpha_i (imaginary part i)\n"
      << "  --alpha_j=<f32>             Epilogue scalar alpha_j (imaginary part j)\n"
      << "  --alpha_k=<f32>             Epilogue scalar alpha_k (imaginary part k)\n"
      << "  --beta=<f32>                Epilogue scalar beta (real part)\n"
      << "  --beta_i=<f32>              Epilogue scalar beta_i (imaginary part i)\n"
      << "  --beta_j=<f32>              Epilogue scalar beta_j (imaginary part j)\n"
      << "  --beta_k=<f32>              Epilogue scalar beta_k (imaginary part k)\n"
      << "  --iterations=<int>          Number of profiling iterations to perform.\n\n";

    out << "\n\nExamples:\n\n"
      << "$ ./examples/21_quaternion_gemm/21_quaternion_gemm --batch=7 --m=1024 --n=512 --k=1024 \\\n"
      << "     --alpha=2 --alpha_i=-2 --beta=0.707 --beta_i=-.707\n\n";

    return out;
  }

  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const {

    // Number of real-valued multiply-adds (each quaternion multiply is 16 real FMAs);
    // widen to int64_t before multiplying to avoid 32-bit overflow
    int64_t fmas = int64_t(problem_size.product()) * batch_count * 16;
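    // Where the factor of 16 comes from (a sketch, using cutlass::Quaternion's w/x/y/z
    // convention with w as the real part, matching the alpha/beta parsing above):
    //   (a * b).w = a.w*b.w - a.x*b.x - a.y*b.y - a.z*b.z
    //   (a * b).x = a.w*b.x + a.x*b.w + a.y*b.z - a.z*b.y
    //   (a * b).y = a.w*b.y - a.x*b.z + a.y*b.w + a.z*b.x
    //   (a * b).z = a.w*b.z + a.x*b.y - a.y*b.x + a.z*b.w
    // i.e. 16 real multiplies (and 12 real additions) per quaternion multiply.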

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

// The code section below describes the data types of the input and output matrices and the
// computation between elements of the input matrices.
using precision = float;
using Element = cutlass::Quaternion<float>;
using ElementComputeEpilogue = Element;  // <- data type of epilogue operations
using ElementAccumulator = Element;      // <- data type of accumulator
using ElementInputA = Element;           // <- data type of elements in input matrix A
using ElementInputB = Element;           // <- data type of elements in input matrix B
using ElementOutput = Element;           // <- data type of elements in output matrix D

// The code section below describes the matrix layouts of the input and output matrices:
// Row Major for Matrix A, Column Major for Matrix B, and Row Major for Matrices C and D
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassSimt;

// This code section describes the CUDA SM architecture number
using SmArch = cutlass::arch::Sm50;
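
// Note: this example uses the SIMT math path (OpClassSimt above); with the Sm50 baseline it is
// expected to run on any GPU of compute capability 5.0 or newer, since the quaternion arithmetic
// here maps onto regular CUDA cores rather than tensor cores.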

// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<64, 64, 4>;  // <- threadblock tile M = 64, N = 64, K = 4
// This code section describes the tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 16, 4>;  // <- warp tile M = 32, N = 16, K = 4
// This code section describes the size of the MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<1, 1, 1>;  // <- MMA Op tile M = 1, N = 1, K = 1
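
// With these shapes, each 64x64 threadblock tile is covered by (64/32) x (64/16) = 2 x 4 = 8
// warps, i.e. 256 threads per threadblock.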

// This code section describes how threadblocks are scheduled on the GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- Defaults

// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                     // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- the number of elements per vectorized
                                                       //    memory access. For a byte, it's 16
                                                       //    elements. This becomes the vector width
                                                       //    of math instructions in the epilogue too
    ElementAccumulator,                                // <- data type of accumulator
    ElementComputeEpilogue>;                           // <- data type for alpha/beta in linear combination function
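
// For ElementOutput = Quaternion<float>, sizeof_bits evaluates to 128, so the alignment
// expression above works out to 128 / 128 = 1 quaternion (four floats) per epilogue access.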

// Number of pipeline stages to use
constexpr int NumStages = 2;

using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
                                         LayoutInputA,
                                         ElementInputB,
                                         LayoutInputB,
                                         ElementOutput,
                                         LayoutOutput,
                                         ElementAccumulator,
                                         MMAOp,
                                         SmArch,
                                         ShapeMMAThreadBlock,
                                         ShapeMMAWarp,
                                         ShapeMMAOp,
                                         EpilogueOp,
                                         SwizzleThreadBlock,
                                         NumStages>;

int run(Options options) {

  // PASS/FAIL status
  bool passed = true;

  // Create a tuple of problem size for matrix multiplication
  cutlass::gemm::GemmCoord problem_size = options.problem_size;

  // Initialize tensors using CUTLASS helper functions
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
      problem_size.mk());  // <- Create matrix A with dimensions M x K
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
      problem_size.kn());  // <- Create matrix B with dimensions K x N
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(
      problem_size.mn());  // <- Create matrix C with dimensions M x N
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
      problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
                           //    CUTLASS kernel
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
      problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
                           //    reference kernel

  // Fill input and output matrices on host using CUTLASS helper functions. The arguments after
  // the view are: seed, max, min, and the number of bits of fractional precision (0 rounds the
  // random values to integers in [min, max]).
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,    // <- seed
      4,    // <- max
      -4,   // <- min
      0);   // <- bits; fill matrix A on host with uniformly random integer-valued data

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      4,
      -4,
      0);  // <- Fill matrix B on host the same way

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(),
      1,
      4,
      -4,
      0);  // <- Fill matrix C on host the same way

  cutlass::reference::host::TensorFill(
      tensor_d.host_view());  // <- fill matrix D on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_ref_d.host_view());  // <- fill matrix D for reference on host with zeros

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();
  tensor_ref_d.sync_device();

  // Initialize alpha and beta for the epilogue computation from the parsed command line
  // options (defaults: alpha = 1, beta = 0)
  ElementComputeEpilogue alpha = options.alpha;
  ElementComputeEpilogue beta = options.beta;

  // Split the K dimension into 1 partition (i.e., split-K disabled)
  int split_k_slices = 1;

  // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch the
  // instantiated CUTLASS kernel
  typename Gemm::Arguments arguments{problem_size,           // <- problem size of matrix multiplication
                                     tensor_a.device_ref(),  // <- reference to matrix A on device
                                     tensor_b.device_ref(),  // <- reference to matrix B on device
                                     tensor_c.device_ref(),  // <- reference to matrix C on device
                                     tensor_d.device_ref(),  // <- reference to matrix D on device
                                     {alpha, beta},          // <- tuple of alpha and beta
                                     split_k_slices};        // <- k-dimension split factor

  // Using the arguments, query for extra workspace required for matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
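
  // With split_k_slices == 1 the workspace query typically returns zero and this allocation is
  // empty; a nonzero workspace is only needed when partial sums from multiple K-slices must be
  // reduced into the final output.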

  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm_op;

  // Check whether this problem size is supported by the kernel
  cutlass::Status status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(status);

  // Result structure
  Result result;

  //
  // Construct events
  //

  cudaEvent_t events[2];

  for (auto & event : events) {
    result.error = cudaEventCreate(&event);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
      return -1;
    }
  }

  // Record an event at the start of a series of GEMMs
  result.error = cudaEventRecord(events[0]);
  if (result.error != cudaSuccess) {
    std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
    return -1;
  }

  //
  // Run profiling loop
  //

  for (int iter = 0; iter < options.iterations; ++iter) {

    // Launch initialized CUTLASS kernel
    status = gemm_op();
    CUTLASS_CHECK(status);
  }

  //
  // Stop profiling loop
  //

  // Record an event when the GEMMs are complete
  result.error = cudaEventRecord(events[1]);
  if (result.error != cudaSuccess) {
    std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
    return -1;
  }

  // Wait for work on the device to complete.
  result.error = cudaEventSynchronize(events[1]);
  if (result.error != cudaSuccess) {
    std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
    return -1;
  }

  // Measure elapsed runtime
  float runtime_ms = 0;
  result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
  if (result.error != cudaSuccess) {
    std::cerr << "cudaEventElapsedTime() failed: " << cudaGetErrorString(result.error) << std::endl;
    return -1;
  }

  // Compute average runtime and GFLOPs.
  result.runtime_ms = double(runtime_ms) / double(options.iterations);
  result.gflops = options.gflops(result.runtime_ms / 1000.0);
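
  // Sanity check for the numbers reported below: the default 1024 x 1024 x 1024 problem with
  // batch_count = 1 performs 2 * 16 * 1024^3 ~ 34.4 GFLOP per run, so a 10 ms average runtime
  // corresponds to roughly 3.4 TFLOP/s.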

  // Cleanup
  for (auto event : events) {
    (void)cudaEventDestroy(event);
  }

  if (options.reference_check) {

    // Create instantiation for device reference gemm kernel
    cutlass::reference::device::Gemm<ElementInputA,
                                     LayoutInputA,
                                     ElementInputB,
                                     LayoutInputB,
                                     ElementOutput,
                                     LayoutOutput,
                                     ElementComputeEpilogue,
                                     ElementComputeEpilogue> gemm_device;

    // Launch device reference gemm kernel
    gemm_device(problem_size,
                alpha,
                tensor_a.device_ref(),
                tensor_b.device_ref(),
                beta,
                tensor_c.device_ref(),
                tensor_ref_d.device_ref());

    // Wait for kernels to finish
    cudaDeviceSynchronize();

    // Copy output data from CUTLASS and reference kernel to host for comparison
    tensor_d.sync_host();
    tensor_ref_d.sync_host();

    // Check whether the outputs from the CUTLASS kernel and the reference kernel are equal
    passed &= cutlass::reference::host::TensorEquals(
        tensor_d.host_view(),
        tensor_ref_d.host_view());
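
    // Bit-exact comparison is reasonable here: the tensors are filled with small integers, so
    // the accumulation over K stays far below 2^24 and is exact in float regardless of
    // summation order; both kernels then apply the same epilogue to identical accumulators.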
  }

  if (passed) {
    std::cout << "Runtime: " << result.runtime_ms << " ms" << std::endl;
    std::cout << " GFLOPs: " << result.gflops << std::endl;
  }

  std::cout << (passed ? "Passed" : "Failed") << std::endl;
  return (passed ? 0 : -1);
}

int main(int argc, char const** argv) {

  Options options;
  options.parse(argc, argv);

  if (options.help) {
    options.print_usage(std::cout) << std::endl;
    return 0;
  }

  printf("%d x %d x %d Single Precision Quaternion Matrix Multiply\n",
         options.problem_size.m(), options.problem_size.n(), options.problem_size.k());

  if (!options.valid()) {
    std::cerr << "Invalid problem." << std::endl;
    return -1;
  }

  return run(options);
}