792 lines
27 KiB
Plaintext
792 lines
27 KiB
Plaintext
/***************************************************************************************************
|
|
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
|
|
/*
|
|
This example shows how to compute conv2d gradient with respect to weight (wgrad). In wgrad, the K dimension of
|
|
implicit GEMM, corresponding to the sequential reduction loop, is very large (N * P * Q). Split-k with parallel
|
|
reduction is highly effective for such cases. Given split_k_slices parameter, it partitions the K loop into
|
|
split_k_slices chunks and computes partial reductions in parallel across different blocks. After that,
|
|
a parallel reduction kernel is launched to accumulate partial reductions.
|
|
In practice, wgrad requires fp32 accumulation to avoid overflow. When the input is fp16, some care is needed
|
|
to correctly instantiate the GEMM template.
|
|
*/
|
|
|
|
#include <iostream>
|
|
#include <fstream>
|
|
#include <sstream>
|
|
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/gemm/device/gemm.h"
|
|
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
|
|
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
|
|
|
#include "cutlass/util/command_line.h"
|
|
#include "cutlass/util/host_tensor.h"
|
|
#include "cutlass/util/tensor_view_io.h"
|
|
#include "cutlass/util/reference/device/gemm.h"
|
|
#include "cutlass/util/reference/host/tensor_compare.h"
|
|
#include "cutlass/util/reference/host/tensor_copy.h"
|
|
#include "cutlass/util/reference/host/tensor_fill.h"
|
|
#include "cutlass/util/reference/device/convolution.h"
|
|
#include "cutlass/util/tensor_view_io.h"
|
|
#include "cutlass/reduction/device/reduce_split_k.h"
|
|
#include "cutlass/reduction/thread/reduction_operators.h"
|
|
|
|
#include "helper.h"
|
|
|
|
// The code section below describes datatype for input, output tensors and computation between
|
|
// elements
|
|
// In Wgrad, fp32 accumulation is necessary in practice.
|
|
using ElementAccumulator = float; // Data type of accumulator
|
|
using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta)
|
|
using ElementInputA = cutlass::half_t; // Data type of elements in input tensor
|
|
using ElementInputB = cutlass::half_t; // Data type of elements in input tensor
|
|
using ElementOutput = cutlass::half_t; // Data type of elements in output tensor
|
|
using ElementC = ElementOutput;
|
|
using ElementCompute = ElementComputeEpilogue;
|
|
using LayoutInputA = cutlass::layout::TensorNHWC;
|
|
using LayoutInputB = cutlass::layout::TensorNHWC;
|
|
using LayoutOutput = cutlass::layout::TensorNHWC;
|
|
|
|
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
|
|
using MMAOp = cutlass::arch::OpClassTensorOp;
|
|
|
|
// This code section describes CUDA SM architecture number
|
|
using SmArch = cutlass::arch::Sm80;
|
|
|
|
// This code section describes the tile size a thread block will compute
|
|
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape
|
|
|
|
// This code section describes tile size a warp will compute
|
|
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape
|
|
|
|
// This code section describes the size of MMA op
|
|
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape
|
|
|
|
// This code section describes how threadblocks are scheduled on GPU
|
|
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
|
|
|
// Number of pipelines you want to use
|
|
constexpr int NumStages = 3;
|
|
|
|
// This code section describe iterator algorithm selected is Analytic or Optimized
|
|
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
|
|
|
|
// We need two epilogue functors - one for GEMM and another for the final reduction.
|
|
// The epilogue for GEMM is not used, but needed to instantiate the CUTLASS kernel template.
|
|
// Note that, when the input is fp16 and accumulation is fp32, the output of GEMM needs to be fp32,
|
|
// the final reduction is done in fp32, and the reduction epilogue converts fp32 outputs to fp16.
|
|
// Therefore, the output type of the GEMM epilogue is ElementCompute, not ElementOutput.
|
|
|
|
// This code section describes the epilogue part of the kernel, we use default value
|
|
using EpilogueOpGEMM = cutlass::epilogue::thread::LinearCombination<
|
|
ElementCompute, // Data type of output matrix.
|
|
128 / cutlass::sizeof_bits<ElementCompute>::value, // The number of elements per vectorized.
|
|
// memory access. This becomes the vector width of
|
|
// math instructions in the epilogue too.
|
|
ElementAccumulator, // Data type of accumulator
|
|
ElementComputeEpilogue>; // Data type for alpha/beta in linear combination
|
|
|
|
// The epilogue functor for reduction. This is the one that is actually used.
|
|
using EpilogueOpReduction = cutlass::epilogue::thread::LinearCombination<
|
|
ElementOutput, // Data type of output matrix.
|
|
128 / cutlass::sizeof_bits<ElementOutput>::value, // The number of elements per vectorized.
|
|
// memory access. This becomes the vector width of
|
|
// math instructions in the epilogue too.
|
|
ElementAccumulator, // Data type of accumulator
|
|
ElementComputeEpilogue>; // Data type for alpha/beta in lin
|
|
|
|
using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
|
|
ElementInputA, LayoutInputA,
|
|
ElementInputB, LayoutInputB,
|
|
ElementAccumulator, LayoutOutput,
|
|
ElementAccumulator,
|
|
MMAOp,
|
|
SmArch,
|
|
ThreadblockShape,
|
|
WarpShape,
|
|
InstructionShape,
|
|
EpilogueOpGEMM,
|
|
SwizzleThreadBlock,
|
|
NumStages,
|
|
cutlass::arch::OpMultiplyAdd,
|
|
IteratorAlgorithm
|
|
>::Kernel;
|
|
|
|
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;
|
|
|
|
using EpilogueOutputOp = EpilogueOpReduction;
|
|
|
|
/// Reduction kernel
|
|
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
|
|
ElementAccumulator,
|
|
typename EpilogueOutputOp::ElementAccumulator,
|
|
EpilogueOutputOp::kCount
|
|
>;
|
|
|
|
using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
|
|
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
|
|
EpilogueOutputOp,
|
|
ReductionOp
|
|
>;
|
|
|
|
using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
|
|
using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Command line options parsing
|
|
struct Options {
|
|
|
|
bool help;
|
|
cutlass::Tensor4DCoord input_size;
|
|
cutlass::Tensor4DCoord filter_size;
|
|
cutlass::Tensor4DCoord padding;
|
|
cutlass::MatrixCoord conv_stride;
|
|
cutlass::MatrixCoord dilation;
|
|
bool reference_check;
|
|
bool measure_performance;
|
|
int iterations;
|
|
bool save_workspace;
|
|
ElementComputeEpilogue alpha;
|
|
ElementComputeEpilogue beta;
|
|
int split_k_slices;
|
|
bool benchmark;
|
|
std::string tag;
|
|
|
|
Options():
|
|
help(false),
|
|
input_size(1, 32, 32, 32),
|
|
filter_size(32, 3, 3, 32),
|
|
padding(1, 1, 1, 1),
|
|
conv_stride(1, 1),
|
|
dilation(1, 1),
|
|
reference_check(true),
|
|
measure_performance(false),
|
|
iterations(20),
|
|
save_workspace(false),
|
|
alpha(1),
|
|
beta(0),
|
|
split_k_slices(8),
|
|
benchmark(false) { }
|
|
|
|
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
|
|
bool valid() {
|
|
|
|
//
|
|
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
|
|
// all pointers, strides, and tensor extents must be divisible by 8 elements.
|
|
//
|
|
int const kAlignment = 8;
|
|
|
|
if ((input_size.c() % kAlignment) ||
|
|
(filter_size.n() % kAlignment)) {
|
|
|
|
// misaligned tensors
|
|
return false;
|
|
}
|
|
|
|
// Invalid padding
|
|
if ((padding.h() != filter_size.h() / 2) ||
|
|
(padding.w() != filter_size.w() / 2)) {
|
|
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Updates input and filter sizes
|
|
void update(
|
|
cutlass::Tensor4DCoord input_size,
|
|
cutlass::Tensor4DCoord filter_size,
|
|
cutlass::MatrixCoord stride) {
|
|
|
|
this->input_size = input_size;
|
|
this->filter_size = filter_size;
|
|
conv_stride = stride;
|
|
|
|
padding.n() = filter_size.h() / 2;
|
|
padding.h() = filter_size.h() / 2;
|
|
padding.w() = filter_size.w() / 2;
|
|
padding.c() = filter_size.w() / 2;
|
|
}
|
|
|
|
// Parses the command line
|
|
void parse(int argc, char const **args) {
|
|
cutlass::CommandLine cmd(argc, args);
|
|
|
|
if (cmd.check_cmd_line_flag("help")) {
|
|
help = true;
|
|
}
|
|
|
|
if (cmd.check_cmd_line_flag("ref-check")) {
|
|
reference_check = true;
|
|
}
|
|
|
|
if (cmd.check_cmd_line_flag("perf-check")) {
|
|
measure_performance = true;
|
|
}
|
|
|
|
if (cmd.check_cmd_line_flag("save-workspace")) {
|
|
save_workspace = true;
|
|
}
|
|
|
|
if (cmd.check_cmd_line_flag("benchmark")) {
|
|
benchmark = true;
|
|
}
|
|
|
|
cmd.get_cmd_line_argument("n", input_size.n());
|
|
cmd.get_cmd_line_argument("h", input_size.h());
|
|
cmd.get_cmd_line_argument("w", input_size.w());
|
|
cmd.get_cmd_line_argument("c", input_size.c());
|
|
|
|
cmd.get_cmd_line_argument("k", filter_size.n());
|
|
cmd.get_cmd_line_argument("r", filter_size.h());
|
|
cmd.get_cmd_line_argument("s", filter_size.w());
|
|
filter_size.c() = input_size.c();
|
|
|
|
cmd.get_cmd_line_argument("alpha", alpha);
|
|
cmd.get_cmd_line_argument("beta", beta);
|
|
cmd.get_cmd_line_argument("split-k-slices", split_k_slices);
|
|
|
|
cmd.get_cmd_line_argument("iterations", iterations);
|
|
cmd.get_cmd_line_argument("tag", tag);
|
|
|
|
if (filter_size.h() == 3 && filter_size.w() == 3) {
|
|
padding = {1, 1, 1, 1};
|
|
}
|
|
else {
|
|
filter_size.h() = 1;
|
|
filter_size.w() = 1;
|
|
padding = {0, 0, 0, 0};
|
|
}
|
|
}
|
|
|
|
/// Prints the usage statement.
|
|
std::ostream & print_usage(std::ostream &out) const {
|
|
|
|
out << "30_wgrad_split_k example\n\n"
|
|
<< " This example shows how to compute conv2d gradient with respect to weight (wgrad).\n"
|
|
<< " In wgrad, the K dimension of impligit GEMM, corresponding to the sequential reduction loop, is very large (N * P * Q).\n"
|
|
<< " Split-k with parallel reduction is highly effective for such cases.\n\n"
|
|
<< "Options:\n\n"
|
|
<< " --help If specified, displays this usage statement.\n\n"
|
|
<< " --n=<int> Input tensor extent N\n"
|
|
<< " --h=<int> Input tensor extent H\n"
|
|
<< " --w=<int> Input tensor extent W\n"
|
|
<< " --c=<int> Input tensor extent C\n"
|
|
<< " --k=<int> Filter extent K\n"
|
|
<< " --r=<int> Filter extent R\n"
|
|
<< " --s=<int> Filter extent S\n\n"
|
|
<< " --alpha=<float> Epilogue scalar alpha\n"
|
|
<< " --beta=<float> Epilogue scalar beta\n\n"
|
|
<< " --split-k-slices=<int> Split-k factor \n\n"
|
|
<< " --ref-check If set (true), reference check on the host is computed\n"
|
|
<< " --perf-check If set (true), performance is measured.\n"
|
|
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
|
|
<< " --iterations=<int> Number of profiling iterations to perform.\n"
|
|
<< " --save-workspace If set, workspace is written to a text file.\n"
|
|
<< " --tag=<string> String to replicate across the first column in the results table\n";
|
|
|
|
out << "\n\nExamples:\n\n"
|
|
<< "$ ./examples/30_wgrad_split_k/30_wgrad_split_k --n=32 --h=224 --w=224 --c=128 --k=256 --r=3 --s=3 --split-k-slices=8\n\n";
|
|
|
|
return out;
|
|
}
|
|
|
|
/// Computes the output tensor size (NPQK)
|
|
cutlass::Tensor4DCoord output_size() const {
|
|
return cutlass::Tensor4DCoord(input_size.n(),
|
|
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
|
|
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
|
|
filter_size.n());
|
|
}
|
|
|
|
/// Compute performance in GFLOP/s
|
|
double gflops(double runtime_s) const {
|
|
|
|
// Number of multiply-adds = NPQK * CRS
|
|
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
|
|
|
|
// Two flops per multiply-add
|
|
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
|
|
}
|
|
};
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
struct Result {
|
|
double runtime_ms;
|
|
double gflops;
|
|
cutlass::Status status;
|
|
cutlass::Status reference_check;
|
|
cudaError_t error;
|
|
|
|
Result():
|
|
runtime_ms(0),
|
|
gflops(0),
|
|
status(cutlass::Status::kSuccess),
|
|
reference_check(cutlass::Status::kInvalid),
|
|
error(cudaSuccess) { }
|
|
|
|
static std::ostream & print_header(std::ostream &out, Options const &options) {
|
|
|
|
if (!options.tag.empty()) {
|
|
out << "Name,";
|
|
}
|
|
|
|
out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs";
|
|
|
|
return out;
|
|
}
|
|
|
|
std::ostream & print(std::ostream &out, int idx, Options const &options) {
|
|
|
|
if (!options.tag.empty()) {
|
|
out << options.tag << ",";
|
|
}
|
|
|
|
out
|
|
<< "conv_" << idx << ","
|
|
<< options.input_size.n() << ","
|
|
<< options.input_size.h() << ","
|
|
<< options.input_size.w() << ","
|
|
<< options.input_size.c() << ","
|
|
<< options.filter_size.n() << ","
|
|
<< options.filter_size.h() << ","
|
|
<< options.filter_size.w() << ","
|
|
<< options.conv_stride.row() << ","
|
|
<< options.conv_stride.column() << ","
|
|
<< runtime_ms << ","
|
|
<< gflops;
|
|
|
|
return out;
|
|
}
|
|
};
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Runs one benchmark
|
|
Result profile_convolution(Options const &options) {
|
|
|
|
Result result;
|
|
|
|
//
|
|
// Allocate host-device tensors using the CUTLASS Utilities.
|
|
//
|
|
|
|
// Inputs are the output gradient and the original activation.
|
|
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.output_size());
|
|
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.input_size);
|
|
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.filter_size);
|
|
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.filter_size);
|
|
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.filter_size);
|
|
|
|
//
|
|
// Initialize tensors
|
|
//
|
|
|
|
// Fill tensor A on host with uniform-distribution random data
|
|
cutlass::reference::host::TensorFillRandomUniform(
|
|
tensor_a.host_view(),
|
|
1,
|
|
ElementInputA(7),
|
|
ElementInputA(-8),
|
|
0);
|
|
|
|
// Fill tensor B on host with uniform-distribution random data
|
|
cutlass::reference::host::TensorFillRandomUniform(
|
|
tensor_b.host_view(),
|
|
1,
|
|
ElementInputB(7),
|
|
ElementInputB(-8),
|
|
0);
|
|
|
|
// Fill tensor C, D on host with zeros
|
|
cutlass::reference::host::TensorFill(tensor_c.host_view());
|
|
|
|
cutlass::reference::host::TensorFill(tensor_d.host_view());
|
|
|
|
// Fill tensor D for reference on host with zeros
|
|
cutlass::reference::host::TensorFill(tensor_ref_d.host_view());
|
|
|
|
// Copy data from host to GPU
|
|
tensor_a.sync_device();
|
|
tensor_b.sync_device();
|
|
tensor_c.sync_device();
|
|
tensor_d.sync_device();
|
|
tensor_ref_d.sync_device();
|
|
|
|
//
|
|
// Define arguments for CUTLASS Convolution
|
|
//
|
|
|
|
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
|
|
|
|
// Partition the GEMM K loop into split_k_slices chunks
|
|
int split_k_slices = options.split_k_slices;
|
|
|
|
// Construct Conv2dProblemSize with user defined output size
|
|
// Do not forget to pass the last argument.
|
|
cutlass::conv::Conv2dProblemSize problem_size(
|
|
options.input_size,
|
|
options.filter_size,
|
|
options.padding,
|
|
options.conv_stride,
|
|
options.dilation,
|
|
options.output_size(),
|
|
mode,
|
|
split_k_slices
|
|
);
|
|
|
|
using cutlass::layout::TensorNHWC;
|
|
|
|
cutlass::conv::SplitKMode const split_k_mode = cutlass::conv::SplitKMode::kParallel;
|
|
|
|
// Since the epilogue is not computed after GEMM, there is no need to pass the C tensor and
|
|
// alpha and beta can be set to 1 and 0 respectively.
|
|
// Moreover, since the output will be written to the workspace, there is no need to pass
|
|
// the D tensor as well.
|
|
// Do not forget to pass the last argument.
|
|
typename ImplicitGemm::Arguments arguments{
|
|
problem_size,
|
|
tensor_a.device_ref(),
|
|
tensor_b.device_ref(),
|
|
{nullptr, TensorNHWC()},
|
|
{nullptr, TensorNHWC()},
|
|
{ElementCompute(1), ElementCompute(0)},
|
|
split_k_mode
|
|
};
|
|
|
|
//
|
|
// Initialize CUTLASS Convolution
|
|
//
|
|
|
|
ImplicitGemm implicit_gemm;
|
|
|
|
size_t workspace_size = implicit_gemm.get_workspace_size(arguments);
|
|
|
|
// Split-K requires non-zero workspace size. The workspace size grows linearly with split_k_slices.
|
|
std::cout << "split-k-slices: " << split_k_slices << std::endl;
|
|
std::cout << "workspace size: " << workspace_size << std::endl;
|
|
|
|
// Allocate workspace memory
|
|
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
|
|
|
result.status = implicit_gemm.can_implement(arguments);
|
|
CUTLASS_CHECK(result.status);
|
|
|
|
// After the workspace is allocated, we point the GEMM destination pointer to the workspace.
|
|
TensorNHWC layout_D{TensorNHWC::packed(options.filter_size)};
|
|
arguments.ref_D.reset(reinterpret_cast<ElementCompute*>(workspace.get()), layout_D);
|
|
|
|
result.status = implicit_gemm.initialize(arguments, workspace.get());
|
|
CUTLASS_CHECK(result.status);
|
|
|
|
//
|
|
// Launch initialized CUTLASS kernel
|
|
//
|
|
result.status = implicit_gemm();
|
|
|
|
CUTLASS_CHECK(result.status);
|
|
|
|
if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
|
|
// Do reduction
|
|
ReductionDevice reduction_op;
|
|
auto& status = result.status;
|
|
static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemm::kConvolutionalOperator;
|
|
typename ReductionDevice::Arguments reduction_args(
|
|
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
|
|
problem_size.split_k_slices,
|
|
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
|
|
// Reduction input
|
|
{
|
|
reinterpret_cast<ElementAccumulator*> (workspace.get()),
|
|
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
|
|
},
|
|
// Destination
|
|
{
|
|
tensor_d.device_data(),
|
|
ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
|
|
},
|
|
// Source
|
|
{
|
|
tensor_c.device_data(),
|
|
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
|
|
},
|
|
{options.alpha, options.beta}
|
|
);
|
|
|
|
status = reduction_op.initialize(reduction_args, nullptr);
|
|
status = reduction_op();
|
|
}
|
|
|
|
//
|
|
// Optional reference check
|
|
//
|
|
|
|
if (options.reference_check) {
|
|
std::cout << "Verification on device...\n";
|
|
|
|
// Compute with reference implementation
|
|
cutlass::reference::device::Conv2dWgrad<
|
|
ElementInputA,
|
|
LayoutInputA,
|
|
ElementInputB,
|
|
LayoutInputB,
|
|
ElementOutput,
|
|
LayoutOutput,
|
|
ElementComputeEpilogue,
|
|
ElementAccumulator,
|
|
cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
|
|
>(
|
|
problem_size,
|
|
tensor_a.device_ref(),
|
|
tensor_b.device_ref(),
|
|
tensor_c.device_ref(),
|
|
tensor_ref_d.device_ref(),
|
|
options.alpha,
|
|
options.beta
|
|
);
|
|
|
|
// Check if output from CUTLASS kernel and reference kernel are equal or not
|
|
tensor_c.sync_host();
|
|
tensor_d.sync_host();
|
|
tensor_ref_d.sync_host();
|
|
|
|
bool passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view());
|
|
|
|
if (!passed) {
|
|
result.reference_check = cutlass::Status::kErrorInternal;
|
|
std::cout << "ERROR - results miscompared.\n";
|
|
}
|
|
else {
|
|
result.reference_check = cutlass::Status::kSuccess;
|
|
std::cout << "Passed.\n";
|
|
}
|
|
}
|
|
else {
|
|
result.reference_check = cutlass::Status::kInvalid;
|
|
}
|
|
|
|
if (options.save_workspace) {
|
|
|
|
std::stringstream ss;
|
|
|
|
ss << "30_wgrad_split_k_"
|
|
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
|
|
<< "_"
|
|
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
|
|
<< ".dat";
|
|
|
|
std::ofstream output_workspace(ss.str());
|
|
|
|
output_workspace
|
|
<< "Input = \n" << tensor_a.host_view() << "\n\n"
|
|
<< "Filters = \n" << tensor_b.host_view() << "\n\n";
|
|
|
|
if (options.reference_check) {
|
|
output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
|
|
}
|
|
|
|
output_workspace << "Computed = \n" << tensor_c.host_view() << std::endl;
|
|
|
|
std::cout << "Results written to '" << ss.str() << "'." << std::endl;
|
|
}
|
|
|
|
//
|
|
// Performance measurement
|
|
//
|
|
|
|
if (options.measure_performance) {
|
|
|
|
cudaEvent_t events[2];
|
|
|
|
for (auto & event : events) {
|
|
result.error = cudaEventCreate(&event);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
}
|
|
|
|
// Record an event at the start of a series of convolution operations.
|
|
result.error = cudaEventRecord(events[0]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Launch a sequence of implicit GEMM operations on the device
|
|
for (int iteration = 0; iteration < options.iterations; ++iteration) {
|
|
result.status = implicit_gemm();
|
|
CUTLASS_CHECK(result.status);
|
|
}
|
|
|
|
// Record an event when the convolutions have been launched.
|
|
result.error = cudaEventRecord(events[1]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Wait for work on the device to complete.
|
|
result.error = cudaEventSynchronize(events[1]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Measure elapsed runtime
|
|
float runtime_ms = 0;
|
|
result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
|
|
if (result.error != cudaSuccess) {
|
|
std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
|
|
return result;
|
|
}
|
|
|
|
// Print average runtime and GFLOPs.
|
|
result.runtime_ms = double(runtime_ms) / double(options.iterations);
|
|
result.gflops = options.gflops(result.runtime_ms / 1000.0);
|
|
|
|
// Cleanup
|
|
for (auto event : events) {
|
|
(void)cudaEventDestroy(event);
|
|
}
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
int main(int argc, char const **args) {
|
|
bool notSupported = false;
|
|
|
|
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
|
|
//
|
|
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
|
|
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
|
|
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
|
|
notSupported = true;
|
|
}
|
|
|
|
cudaDeviceProp props;
|
|
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
|
|
|
|
if (!(props.major >= 8)) {
|
|
std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
|
|
<< std::endl;
|
|
notSupported = true;
|
|
}
|
|
|
|
if (notSupported) {
|
|
return 0;
|
|
}
|
|
|
|
Options options;
|
|
|
|
options.parse(argc, args);
|
|
|
|
if (options.help) {
|
|
options.print_usage(std::cout) << std::endl;
|
|
return 0;
|
|
}
|
|
|
|
if (options.benchmark) {
|
|
// Benchmark several layers
|
|
|
|
int batch_sizes[] = {34, 408};
|
|
|
|
struct Benchmark {
|
|
int h, w, c, k, r, s, stride_h, stride_w;
|
|
} layers[] = {
|
|
{56, 56, 64, 256, 1, 1, 1, 1},
|
|
{56, 56, 64, 64, 1, 1, 1, 1},
|
|
{56, 56, 64, 64, 3, 3, 1, 1},
|
|
{56, 56, 256, 64, 1, 1, 1, 1},
|
|
{56, 56, 256, 512, 1, 1, 2, 2},
|
|
{56, 56, 256, 128, 1, 1, 1, 1},
|
|
{56, 56, 128, 128, 3, 3, 2, 2},
|
|
{28, 28, 128, 512, 1, 1, 1, 1},
|
|
{28, 28, 512, 128, 1, 1, 1, 1},
|
|
{28, 28, 128, 128, 3, 3, 1, 1},
|
|
{28, 28, 512, 1024, 1, 1, 2, 2},
|
|
{28, 28, 512, 256, 1, 1, 1, 1},
|
|
{28, 28, 256, 256, 3, 3, 2, 2},
|
|
{14, 14, 256, 1024, 1, 1, 1, 1},
|
|
{14, 14, 1024, 256, 1, 1, 1, 1},
|
|
{14, 14, 256, 256, 3, 3, 1, 1},
|
|
{14, 14, 1024, 2048, 1, 1, 2, 2},
|
|
{14, 14, 1024, 512, 1, 1, 1, 1},
|
|
{14, 14, 512, 512, 3, 3, 2, 2},
|
|
{ 7, 7, 512, 2048, 1, 1, 1, 1},
|
|
{ 7, 7, 2048, 512, 1, 1, 1, 1},
|
|
{ 7, 7, 512, 512, 3, 3, 1, 1},
|
|
};
|
|
|
|
Result::print_header(std::cout, options) << std::endl;
|
|
|
|
int idx = 1;
|
|
|
|
for (auto const &layer : layers) {
|
|
for (auto N : batch_sizes) {
|
|
options.update({N, layer.h, layer.w, layer.c},
|
|
{layer.k, layer.r, layer.s, layer.c},
|
|
{layer.stride_h, layer.stride_w});
|
|
|
|
Result result = profile_convolution(options);
|
|
result.print(std::cout, idx, options) << std::endl;
|
|
}
|
|
|
|
++idx;
|
|
}
|
|
}
|
|
else {
|
|
|
|
// Execute one problem size
|
|
if (!options.valid()) {
|
|
std::cerr << "Invalid problem." << std::endl;
|
|
return -1;
|
|
}
|
|
|
|
Result result = profile_convolution(options);
|
|
|
|
Result::print_header(std::cout, options) << std::endl;
|
|
result.print(std::cout, 1, options) << std::endl;
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|