/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <iostream>
#include <fstream>
#include <sstream>
#include <type_traits>

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"

#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"

#include "dual_gemm_common.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
    if((val1) <= (val2)) \
        std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
#define CHECK_TRUE(val) \
    if(!(val)) \
        std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
template <
  typename OutputOp,
  typename Element,
  typename Layout>
struct TensorEpilogueForEachFunc {
  /// View type
  using TensorView = cutlass::TensorView<Element, Layout>;

  /// Coordinate in tensor's index space
  using TensorCoord = typename TensorView::TensorCoord;

  /// Parameters structure
  struct Params {

    //
    // Data members
    //

    TensorView view_x0;
    TensorView view_x1;
    TensorView view_y;
    OutputOp output_op;

    //
    // Methods
    //

    Params(
      TensorView view_x0_ = TensorView(),
      TensorView view_x1_ = TensorView(),
      TensorView view_y_ = TensorView(),
      OutputOp output_op_ = OutputOp(typename OutputOp::Params{})
    ):
      view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) {
    }
  };

  Params params;

  CUTLASS_DEVICE
  TensorEpilogueForEachFunc(Params const &params): params(params) {
  }

  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {
    Element const & x0 = params.view_x0.at(coord);
    Element const & x1 = params.view_x1.at(coord);
    Element& y = params.view_y.at(coord);
    y = params.output_op(x0, x1);
  }
};
template <
  typename OutputOp,
  typename Element,
  typename Layout>
void TensorEpilogueForEach(
  cutlass::TensorView<Element, Layout> x0,
  cutlass::TensorView<Element, Layout> x1,
  cutlass::TensorView<Element, Layout> y) {

  using Func = TensorEpilogueForEachFunc<OutputOp, Element, Layout>;
  using Params = typename Func::Params;

  cutlass::reference::device::TensorForEach<Func, Layout::kRank, Params>(
    y.extent(),
    Params(x0, x1, y)
  );
}
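
// Usage sketch (illustrative only, not part of the example itself): given two
// device tensor views x0 and x1 of identical extent and an elementwise binary
// functor OutputOp (e.g. the DualGemm's EpilogueOutputOp2 used later in this
// file, assumed to expose "Element operator()(Element, Element)"), the helper
// above computes y = output_op(x0, x1) per coordinate on the device:
//
//   // cutlass::HostTensor<Element, Layout> x0, x1, y;  // same extents, synced to device
//   TensorEpilogueForEach<OutputOp>(
//       x0.device_view(), x1.device_view(), y.device_view());
//   cudaDeviceSynchronize();
//   y.sync_host();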

////////////////////////////////////////////////////////////////////////////////

template <typename Gemm0_, typename Gemm1_>
struct NonFusedDualGemmRun
{

  using Gemm0 = Gemm0_;
  using Gemm1 = Gemm1_;
  using ElementAccumulator = typename Gemm0::ElementAccumulator;
  using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
  // Methods
  //

  NonFusedDualGemmRun(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  bool initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {

      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, 2, -2, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {

      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {

      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {

      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      std::cerr << "Not implemented\n";
      return false;
    }

    return true;
  }

  /// Executes one test
  bool run(
    cutlass::gemm::GemmCoord problem_size,
    ElementCompute alpha0 = ElementCompute(1),
    ElementCompute beta0 = ElementCompute(0),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(0),
    bool is_profiling = true,
    bool relu = false,
    int warm_ups = 1,
    int runs = 100) {

    //
    // Allocate the GEMM workspace
    //

    cutlass::HostTensor<
      typename Gemm0::ElementA,
      typename Gemm0::LayoutA> tensor_A0(problem_size.mk());

    cutlass::HostTensor<
      typename Gemm0::ElementB,
      typename Gemm0::LayoutB> tensor_B0(problem_size.kn());

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_C0(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()});

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_D0(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> reference_D0(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementB,
      typename Gemm1::LayoutB> tensor_B1(problem_size.kn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_C1(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()});

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_D1(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> reference_D1(problem_size.mn());

    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
    CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
    CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
    CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
    CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
    CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));

    cutlass::reference::host::TensorFill(
      tensor_D0.host_view());
    cutlass::reference::host::TensorFill(
      tensor_D1.host_view());
    cutlass::reference::host::TensorFill(
      reference_D0.host_view());
    cutlass::reference::host::TensorFill(
      reference_D1.host_view());

    tensor_A0.sync_device();
    tensor_B0.sync_device();
    tensor_C0.sync_device();
    tensor_Bias0.sync_device();
    tensor_D0.sync_device();
    reference_D0.sync_device();
    tensor_B1.sync_device();
    tensor_C1.sync_device();
    tensor_Bias1.sync_device();
    tensor_D1.sync_device();
    reference_D1.sync_device();

    //
    // Initialize the GEMM operator
    //

    int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1;
    typename Gemm0::Arguments arguments_0{
      problem_size,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
      tensor_D0.device_ref(),
      {alpha0, beta0},
      split_k_slices
    };

    split_k_slices = Gemm1::kSplitKSerial ? 2 : 1;
    typename Gemm1::Arguments arguments_1{
      problem_size,
      tensor_A0.device_ref(),
      tensor_B1.device_ref(),
      {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
      tensor_D1.device_ref(),
      {alpha1, beta1},
      split_k_slices
    };

    Gemm0 gemm_op_0;
    Gemm1 gemm_op_1;

    // Allocate workspace memory
    cutlass::device_memory::allocation<uint8_t> workspace0(gemm_op_0.get_workspace_size(arguments_0));
    cutlass::device_memory::allocation<uint8_t> workspace1(gemm_op_1.get_workspace_size(arguments_1));

    cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get());

    CUTLASS_CHECK(status);

    status = gemm_op_1.initialize(arguments_1, workspace1.get());

    CUTLASS_CHECK(status);

    for(int i = 0; i < warm_ups; i++) {
      status = gemm_op_0();
      CUTLASS_CHECK(status);
      status = gemm_op_1();
      CUTLASS_CHECK(status);
    }

    if (is_profiling) {
      //
      // Profile the GEMM
      //

      cudaEvent_t start, stop1, stop2;
      cudaEventCreate(&start);
      cudaEventCreate(&stop1);
      cudaEventCreate(&stop2);

      cudaEventRecord(start);

      for(int i = 0; i < runs; i++) {
        status = gemm_op_0();

        CUTLASS_CHECK(status);
      }
      cudaEventRecord(stop1);
      for(int i = 0; i < runs; i++) {
        status = gemm_op_1();

        CUTLASS_CHECK(status);
      }

      cudaEventRecord(stop2);
      cudaDeviceSynchronize();
      float gemm0Time, gemm1Time, totalTime;
      cudaEventElapsedTime(&gemm0Time, start, stop1);
      cudaEventElapsedTime(&gemm1Time, stop1, stop2);
      cudaEventElapsedTime(&totalTime, start, stop2);
      std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
      std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
      std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";
    }

    tensor_D0.sync_host();
    tensor_D1.sync_host();

    //
    // Verify
    //
    cutlass::reference::device::Gemm<
      typename Gemm0::ElementA, typename Gemm0::LayoutA,
      typename Gemm0::ElementB, typename Gemm0::LayoutB,
      typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
      ElementAccumulator, typename Gemm0::Operator>
      reference_gemm_0;

    cutlass::reference::device::Gemm<
      typename Gemm1::ElementA, typename Gemm1::LayoutA,
      typename Gemm1::ElementB, typename Gemm1::LayoutB,
      typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
      ElementAccumulator, typename Gemm1::Operator>
      reference_gemm_1;

    reference_gemm_0(
      problem_size,
      alpha0,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      beta0,
      {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
      reference_D0.device_ref()
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D0.device_view());
    }

    reference_gemm_1(
      problem_size,
      alpha1,
      tensor_A0.device_ref(),
      tensor_B1.device_ref(),
      beta1,
      {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
      reference_D1.device_ref()
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }

    // Wait for kernels to finish
    cudaDeviceSynchronize();
    reference_D0.sync_host();
    reference_D1.sync_host();

    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);

    // Compare the first GEMM's output against its reference
    bool passed0 = cutlass::reference::host::TensorEquals(
      reference_D0.host_view(),
      tensor_D0.host_view());
    CHECK_TRUE(passed0);

    bool passed1 = cutlass::reference::host::TensorEquals(
      reference_D1.host_view(),
      tensor_D1.host_view());
    CHECK_TRUE(passed1);
    if (!passed0 || !passed1) {

      std::stringstream fname;

      fname << "error_DualGemm_device_nonfused.txt";
      std::cerr << "Dumping results in " << fname.str() << "\n";

      std::ofstream file(fname.str());

      file
        << "A0 =\n" << tensor_A0.host_view()
        << "\nB0 =\n" << tensor_B0.host_view()
        << "\nC0 =\n" << tensor_C0.host_view()
        << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
        << "\nD0 =\n" << tensor_D0.host_view()
        << "\nB1 =\n" << tensor_B1.host_view()
        << "\nC1 =\n" << tensor_C1.host_view()
        << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
        << "\n\nReference =\n" << reference_D1.host_view()
        << "\nComputed =\n" << tensor_D1.host_view();
    }
    return passed0 && passed1;
  }
};
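
// Usage sketch (illustrative only): Gemm0 and Gemm1 are assumed to be fully
// specified cutlass::gemm::device::Gemm instantiations defined in the example
// source that includes this header.
//
//   NonFusedDualGemmRun<Gemm0, Gemm1> non_fused;
//   cutlass::gemm::GemmCoord problem_size(1024, 64, 576);
//   bool passed = non_fused.run(problem_size);  // defaults: alpha = 1, beta = 0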

template <typename DualGemm_>
struct DualFusedGemmRun
{

  using DualGemm = DualGemm_;
  using ElementAccumulator = typename DualGemm::ElementAccumulator;
  using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute;
  using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2;

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Scale;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
  // Methods
  //

  DualFusedGemmRun(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_),
    init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  bool initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {

      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, 2, -2, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {

      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {

      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {

      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      std::cerr << "Not implemented\n";
      return false;
    }

    return true;
  }

  /// Executes one test
  bool run(
    cutlass::gemm::GemmCoord problem_size,
    ElementCompute alpha0 = ElementCompute(1),
    ElementCompute beta0 = ElementCompute(1),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(1),
    int batch_count = 1,
    bool broadcast_b1 = false,
    bool is_profiling = true,
    bool relu = false,
    int warm_ups = 1,
    int runs = 100) {

    //
    // Allocate the GEMM workspace
    //

    cutlass::HostTensor<
      typename DualGemm::ElementA,
      typename DualGemm::LayoutA> tensor_A0(
        cutlass::platform::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k()));

    cutlass::HostTensor<
      typename DualGemm::ElementB,
      typename DualGemm::LayoutB0> tensor_B0(
        cutlass::platform::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> tensor_C0(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutScaleBias> tensor_Bias0({batch_count, problem_size.n()});

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> tensor_D0(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> reference_D0(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementB,
      typename DualGemm::LayoutB1> tensor_B1(
        cutlass::platform::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
    if (broadcast_b1) {
      tensor_B1.resize({problem_size.k(), batch_count});
    }

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> tensor_C1(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutScaleBias> tensor_Bias1({batch_count, problem_size.n()});

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> tensor_D1(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> tensor_D2(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> reference_D1(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> reference_D2(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118));
    CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
    CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011));
    CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113));
    CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
    CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));

    cutlass::reference::host::TensorFill(
      tensor_D0.host_view());
    cutlass::reference::host::TensorFill(
      tensor_D1.host_view());
    cutlass::reference::host::TensorFill(
      tensor_D2.host_view());
    cutlass::reference::host::TensorFill(
      reference_D0.host_view());
    cutlass::reference::host::TensorFill(
      reference_D1.host_view());
    cutlass::reference::host::TensorFill(
      reference_D2.host_view());

    tensor_A0.sync_device();
    tensor_B0.sync_device();
    tensor_C0.sync_device();
    tensor_Bias0.sync_device();
    tensor_B1.sync_device();
    tensor_C1.sync_device();
    tensor_Bias1.sync_device();
    tensor_D0.sync_device();
    tensor_D1.sync_device();
    tensor_D2.sync_device();
    reference_D0.sync_device();
    reference_D1.sync_device();
    reference_D2.sync_device();

    //
    // Batch strides (irrelevant when batch_count == 1)
    //

    int64_t batch_stride_A = problem_size.m() * problem_size.k();
    int64_t batch_stride_B0 = problem_size.k() * problem_size.n();
    int64_t batch_stride_B1 = problem_size.k() * problem_size.n();
    if (broadcast_b1) {
      // B1 is a (column) vector
      batch_stride_B1 = problem_size.k();
    }
    int64_t batch_stride_Bias = problem_size.n();
    int64_t batch_stride_D = problem_size.m() * problem_size.n();

    //
    // Initialize the GEMM operator
    //

    int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1;
    typename cutlass::TensorRef<typename DualGemm::ElementC, typename DualGemm::LayoutC> nullptr_ref{};
    decltype(nullptr_ref) ref_B0, ref_B1;
    if (beta0 != ElementCompute(0)) {
      ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)};
    }
    if (beta1 != ElementCompute(0)) {
      ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
    }
    typename DualGemm::Arguments arguments{
      (batch_count > 1 ?
        cutlass::gemm::DualGemmMode::kBatched :
        cutlass::gemm::DualGemmMode::kGemm),
      problem_size,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      ref_B0,
      DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
      (broadcast_b1 ?
        typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
        tensor_B1.device_ref()),
      ref_B1,
      DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref,
      tensor_D2.device_ref(),
      {alpha0, beta0},
      {alpha1, beta1},
      {},
      split_k_slices,
      batch_count,
      batch_stride_A,
      batch_stride_B0,
      batch_stride_B1,
      batch_stride_Bias,
      batch_stride_D,
    };

    //
    // Run the GEMM
    //

    DualGemm b2b_gemm_op;

    cutlass::device_memory::allocation<uint8_t> workspace(b2b_gemm_op.get_workspace_size(arguments));

    cutlass::Status status = b2b_gemm_op.can_implement(arguments);

    CUTLASS_CHECK(status);

    status = b2b_gemm_op.initialize(arguments, workspace.get());

    CUTLASS_CHECK(status);

    for(int i = 0; i < warm_ups; i++) {
      status = b2b_gemm_op();
      CUTLASS_CHECK(status);
    }

    if (is_profiling) {
      //
      // Profile the GEMM
      //

      cudaEvent_t start, stop;
      cudaEventCreate(&start);
      cudaEventCreate(&stop);

      cudaEventRecord(start);

      for(int i = 0; i < runs; i++) {
        status = b2b_gemm_op();
        CUTLASS_CHECK(status);
      }

      cudaEventRecord(stop);
      cudaDeviceSynchronize();
      float gemmTime;
      cudaEventElapsedTime(&gemmTime, start, stop);
      std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
    }

    tensor_D0.sync_host();
    tensor_D1.sync_host();
    tensor_D2.sync_host();

    //
    // Verify
    //

    using GemmUniversal0 = cutlass::gemm::device::GemmUniversal<
      typename DualGemm::ElementA, typename DualGemm::LayoutA,
      typename DualGemm::ElementB, typename DualGemm::LayoutB0,
      typename DualGemm::ElementC, typename DualGemm::LayoutC,
      ElementAccumulator
    >;

    GemmUniversal0 reference_gemm0;

    typename GemmUniversal0::Arguments args0 {
      (batch_count > 1 ?
        cutlass::gemm::GemmUniversalMode::kBatched :
        cutlass::gemm::GemmUniversalMode::kGemm),
      problem_size,
      batch_count,
      {alpha0, beta0},
      tensor_A0.device_data(),
      tensor_B0.device_data(),
      tensor_Bias0.device_data(),
      reference_D0.device_data(),
      batch_stride_A,
      batch_stride_B0,
      batch_stride_Bias,
      batch_stride_D,
      tensor_A0.stride(0),
      tensor_B0.stride(0),
      0, // zero stride for the bias vector
      reference_D0.stride(0),
    };

    status = reference_gemm0.can_implement(args0);
    CUTLASS_CHECK(status);
    status = reference_gemm0(args0);
    CUTLASS_CHECK(status);

    using GemmUniversal1 = cutlass::gemm::device::GemmUniversal<
      typename DualGemm::ElementA, typename DualGemm::LayoutA,
      typename DualGemm::ElementB, typename DualGemm::LayoutB1,
      typename DualGemm::ElementC, typename DualGemm::LayoutC,
      ElementAccumulator
    >;

    GemmUniversal1 reference_gemm1;

    typename GemmUniversal1::Arguments args1 {
      (batch_count > 1 ?
        cutlass::gemm::GemmUniversalMode::kBatched :
        cutlass::gemm::GemmUniversalMode::kGemm),
      problem_size,
      batch_count,
      {alpha1, beta1},
      tensor_A0.device_data(),
      tensor_B1.device_data(),
      tensor_Bias1.device_data(),
      reference_D1.device_data(),
      batch_stride_A,
      batch_stride_B1,
      batch_stride_Bias,
      batch_stride_D,
      tensor_A0.stride(0),
      (broadcast_b1 ? 0 : tensor_B1.stride(0)),
      0, // zero stride for the bias vector
      reference_D1.stride(0),
    };

    status = reference_gemm1.can_implement(args1);
    CUTLASS_CHECK(status);
    status = reference_gemm1(args1);
    CUTLASS_CHECK(status);

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D0.device_view());
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }

    TensorEpilogueForEach<EpilogueOutputOp2>(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view());
    cudaDeviceSynchronize();
    reference_D0.sync_host();
    reference_D1.sync_host();
    reference_D2.sync_host();

    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0);

    bool passed_out0 = true;
    if (DualGemm::kStoreD0) {
      CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
      passed_out0 = cutlass::reference::host::TensorEquals(
        reference_D0.host_view(),
        tensor_D0.host_view());
    }
    CHECK_TRUE(passed_out0);

    bool passed_out1 = true;
    if (DualGemm::kStoreD1) {
      CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
      passed_out1 = cutlass::reference::host::TensorEquals(
        reference_D1.host_view(),
        tensor_D1.host_view());
    }
    CHECK_TRUE(passed_out1);

    bool passed_out2 = cutlass::reference::host::TensorEquals(
      reference_D2.host_view(),
      tensor_D2.host_view());
    CHECK_TRUE(passed_out2);

    bool passed = passed_out0 && passed_out1 && passed_out2;
    if (!passed)
    {
      std::stringstream fname;

      fname << "error_DualGemm_device_fused.txt";
      std::cerr << "Dumping results in " << fname.str() << "\n";

      std::ofstream file(fname.str());

      file
        << "A0 =\n" << tensor_A0.host_view()
        << "\nB0 =\n" << tensor_B0.host_view()
        << "\nC0 =\n" << tensor_C0.host_view()
        << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
        << "\nB1 =\n" << tensor_B1.host_view()
        << "\nC1 =\n" << tensor_C1.host_view()
        << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
        << "\n\nReference0 =\n" << reference_D0.host_view()
        << "\nComputed0 =\n" << tensor_D0.host_view()
        << "\n\nReference1 =\n" << reference_D1.host_view()
        << "\nComputed1 =\n" << tensor_D1.host_view()
        << "\n\nReference2 =\n" << reference_D2.host_view()
        << "\nComputed2 =\n" << tensor_D2.host_view();
    }
    //std::cout << "A0 " << tensor_A0.host_view() << std::endl;
    //std::cout << "reference_D0 " << reference_D0.host_view() << std::endl;
    //std::cout << "reference_D1 " << reference_D1.host_view() << std::endl;
    //std::cout << "reference_D2 " << reference_D2.host_view() << std::endl;
    return passed;
  }

};
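
// Usage sketch (illustrative only): DualGemm is assumed to be a fully
// specified cutlass::gemm::device::DualGemm instantiation from the example
// source that includes this header.
//
//   DualFusedGemmRun<DualGemm> fused;
//   cutlass::gemm::GemmCoord problem_size(1024, 64, 576);
//   // alpha0/beta0 and alpha1/beta1 default to 1; batch_count defaults to 1.
//   bool passed = fused.run(problem_size);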

////////////////////////////////////////////////////////////////////////////////