cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp

1457 lines
46 KiB
C++

/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Testbed and host reference for Epilogue Visitor Tree (EVT) unit tests
*/
#pragma once
#include "gemm_testbed_3x.hpp"
namespace test {
namespace gemm {
namespace device {
/// Host-side tapply (the tapply provided by cute is HOST_DEVICE).
///
/// Applies `f` to each element of tuple `t` selected by the index pack and passes
/// all of the results, in index order, to a single invocation of `g`, whose result
/// is returned.
template <class Tuple, class MapFn, class JoinFn, int... Idx>
constexpr auto
tapply(Tuple&& t, MapFn&& f, JoinFn&& g, cute::seq<Idx...>)
{
  // static_cast<Tuple&&> forwards the value category of `t` to std::get.
  return g(f(std::get<Idx>(static_cast<Tuple&&>(t)))...);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT: Base class for EVT Node
///
/// Shared base of the host-side EVT reference nodes in this file. It publishes the
/// element/layout aliases taken from the GEMM kernel and its testbed, offers the
/// tensor comparison used by nodes that verify device results, and provides
/// default hooks (no C tensor, no D tensor, reference always matches).
/// The hooks are not virtual: a derived node replaces one by declaring a member
/// with the same name.
template <
  typename Gemm_
>
class HostEVTNodeBase {
public:
  using Gemm = Gemm_;
  using TestBedImpl = typename detail::TestbedImpl<Gemm, cutlass::epilogue::thread::Identity, true>;
  using Kernel = typename Gemm::GemmKernel;
  using Epilogue = typename Kernel::CollectiveEpilogue;
  using ElementCompute = typename TestBedImpl::ElementCompute;
  using ElementScalar = typename TestBedImpl::ElementScalar;
  using ElementAccumulator = typename Kernel::ElementAccumulator;
  using ElementC = typename Kernel::ElementC;
  using ElementD = typename Kernel::ElementD;
  using LayoutTagC = typename TestBedImpl::LayoutTagC;
  using LayoutTagD = typename TestBedImpl::LayoutTagD;

private:
  // Selects the comparison used by equality_check(). Initialized in-class so that a
  // default-constructed node never reads an indeterminate value.
  bool _check_relative_equality = false;
  // Factors used for calculating relative equality. These default
  // values are borrowed from those used by default in the CUTLASS
  // profiler for performing relative equality checks.
  float _epsilon = 0.05f;
  float _nonzero_floor = 1.0f / 256.0f;

public:
  /// Default construction: exact comparison is used by equality_check().
  HostEVTNodeBase(){}

  /// \param check_relative_equality  true to compare with TensorRelativelyEquals,
  ///                                 false to compare with TensorEquals.
  HostEVTNodeBase(bool check_relative_equality):
    _check_relative_equality(check_relative_equality) { }

  /// Compares two host tensor views using the mode chosen at construction.
  /// \return the result of the selected cutlass::reference::host comparison.
  template <
    class Element,
    class Layout
  >
  bool equality_check(
    cutlass::TensorView<Element, Layout> const& lhs,
    cutlass::TensorView<Element, Layout> const& rhs) const {
    if (_check_relative_equality) {
      return cutlass::reference::host::TensorRelativelyEquals(
        lhs, rhs, Element(_epsilon), Element(_nonzero_floor)
      );
    }
    else {
      return cutlass::reference::host::TensorEquals(lhs, rhs);
    }
  }

  /// Default hook: this node does not provide the C tensor.
  void* get_tensor_C_ptr() {
    return nullptr;
  }

  /// Default hook: this node does not provide the D tensor.
  void* get_tensor_D_ptr() {
    return nullptr;
  }

  /// Default hook: nothing to verify, nothing written to `error_ss`.
  bool compare_reference(std::stringstream& error_ss) {
    return true;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Accumulator
///
/// Leaf node that injects the accumulator into the tree: visit() converts the
/// incoming accumulator value to ElementCompute and returns it.
template <
  typename Gemm
>
class HostAccumulator: public HostEVTNodeBase<Gemm> {
public:
  using Base = HostEVTNodeBase<Gemm>;
  using TestBedImpl = typename Base::TestBedImpl;
  using ElementAccumulator = typename Base::ElementAccumulator;
  using ElementCompute = typename Base::ElementCompute;

  /// This node carries no arguments.
  struct Arguments { };

private:
  cutlass::NumericConverter<ElementCompute, ElementAccumulator> _to_compute;

public:
  HostAccumulator(){}

  /// `problem_size` and `impl` are accepted for a uniform node constructor
  /// signature; only the comparison mode is forwarded to the base.
  template<typename ProblemShapeType>
  HostAccumulator(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(check_relative_equality) {}

  /// The coordinate arguments are unused; only `acc` contributes to the result.
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc) {
    ElementCompute const converted = _to_compute(acc);
    return converted;
  }

  Arguments get_arguments() {
    return {};
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Scalar Broadcast
///
/// Leaf node that yields one constant for every output coordinate.
/// get_arguments() hands out BroadcastCount copies of ElementCompute(Value); the
/// host reference returns those copies folded together with ReductionFn so that it
/// matches a device node that reduces its scalars the same way.
/// NOTE(review): the device-side reduction is not visible in this file - confirm
/// against the device scalar-broadcast node. For BroadcastCount == 1 (the
/// default) the reference value is simply ElementCompute(Value).
template <
  typename Gemm,
  int Value,
  int BroadcastCount = 1,
  template <class> class ReductionFn = cutlass::multiplies
>
class HostScalarBroadcast : public HostEVTNodeBase<Gemm> {
  static_assert(BroadcastCount >= 1, "HostScalarBroadcast needs at least one scalar");
public:
  using Base = HostEVTNodeBase<Gemm>;
  using ElementCompute = typename Base::ElementCompute;

  struct Arguments {
    ElementCompute scalar[BroadcastCount] = {0};
    ElementCompute const* scalar_ptrs[BroadcastCount] = { nullptr };
    cute::Stride<cute::_0,cute::_0,cute::_0> dScalar{};
  };

private:
  // The value placed in every slot of Arguments::scalar.
  ElementCompute _scalar{};
  // ReductionFn applied across the BroadcastCount copies of _scalar.
  ElementCompute _scalar_reduced{};

public:
  HostScalarBroadcast(){}

  /// `problem_size` and `impl` are unused; they keep the constructor signature
  /// uniform across nodes.
  template<typename ProblemShapeType, typename TestBedImpl>
  HostScalarBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(check_relative_equality), _scalar(ElementCompute(Value)), _scalar_reduced(_scalar) {
    ReductionFn<ElementCompute> reduce_fn;
    // Slot 0 seeds the reduction; fold in the remaining BroadcastCount - 1 copies.
    for (int i = 1; i < BroadcastCount; ++i) {
      _scalar_reduced = reduce_fn(_scalar_reduced, _scalar);
    }
  }

  /// Returns the reduced scalar; coordinates and accumulator are ignored.
  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc) {
    return _scalar_reduced;
  }

  /// Nothing to verify; records the per-slot scalar in the error stream.
  bool compare_reference(std::stringstream& error_ss) {
    error_ss << "Scalar: " << float(_scalar) << "\n\n";
    return true;
  }

  /// Fills every one of the BroadcastCount slots with the scalar (by value, no
  /// pointers), for any BroadcastCount.
  Arguments get_arguments() {
    Arguments args{};
    for (int i = 0; i < BroadcastCount; ++i) {
      args.scalar[i] = _scalar;
      args.scalar_ptrs[i] = nullptr;
    }
    return args;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Row Broadcast
///
/// Leaf node holding a length-N vector. visit() returns entry (n + n_b) of that
/// vector converted to ElementCompute; the row index does not affect the result.
/// ElementBias_ selects the vector's element type; `void` selects ElementC.
template <
  typename Gemm,
  typename ElementBias_=void
>
class HostRowBroadcast: public HostEVTNodeBase<Gemm> {
public:
  using Base = HostEVTNodeBase<Gemm>;
  using ElementBias = std::conditional_t<std::is_void_v<ElementBias_>,
                                         typename Base::ElementC,
                                         ElementBias_>;
  using TestBedImpl = typename Base::TestBedImpl;
  using ElementCompute = typename Base::ElementCompute;
  using LayoutTagVector = cutlass::layout::PackedVectorLayout;

  struct Arguments {
    ElementBias const* ptr_row = nullptr;
    ElementBias null_default = ElementBias(0);
    cute::Stride<cute::_0,cute::_1,cute::_0> dRow = {};
  };

private:
  cutlass::NumericConverter<ElementCompute, ElementBias> _bias_converter;
  cutlass::HostTensor<ElementBias, LayoutTagVector> _bias;
  // Extent of the vector; zero until the problem-size constructor runs.
  int _N = 0;
  TestBedImpl impl_;

public:
  HostRowBroadcast(){}

  /// Allocates the length-N vector, fills it through detail::initialize_tensor
  /// (Uniform distribution, seed offset by 2023) and copies it to the device.
  template<typename ProblemShapeType>
  HostRowBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(check_relative_equality), impl_(impl) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    _N = cute::get<1>(problem_shape_MNKL);
    _bias.resize(cutlass::Coord<1>(_N));

    EXPECT_TRUE(
      detail::initialize_tensor(
        _bias.host_view(), cutlass::Distribution::Uniform,
        impl_.collective_mma_inputs.seed + 2023
      )
    );
    _bias.sync_device();
  }

  /// Returns the vector entry for column (n + n_b), converted to ElementCompute.
  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc) {
    auto TensorBias = cute::make_tensor(_bias.host_data(),
      cute::make_layout(cute::make_shape(cute::_1{}, _N)));

    // The leading mode has extent 1, so its only valid coordinate is 0.
    return _bias_converter(TensorBias(0, n + n_b));
  }

  /// Nothing to verify; dumps the vector so failures elsewhere can be diagnosed.
  bool compare_reference(std::stringstream& error_ss) {
    error_ss
      << "PerColumnBias = \n" << _bias.host_view() << "\n\n";
    return true;
  }

  /// Only the device pointer is set; the other fields keep their defaults.
  Arguments get_arguments() {
    return {_bias.device_data()};
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Column Broadcast
///
/// Leaf node holding a length-M vector. visit() returns entry (m + m_b) of that
/// vector converted to ElementCompute; the column index does not affect the
/// result. ElementBias_ selects the vector's element type; `void` selects ElementC.
template <
  typename Gemm,
  typename ElementBias_=void
>
class HostColBroadcast: public HostEVTNodeBase<Gemm> {
public:
  using Base = HostEVTNodeBase<Gemm>;
  using ElementBias = std::conditional_t<std::is_void_v<ElementBias_>,
                                         typename Base::ElementC,
                                         ElementBias_>;
  using TestBedImpl = typename Base::TestBedImpl;
  using ElementCompute = typename Base::ElementCompute;
  using LayoutTagVector = cutlass::layout::PackedVectorLayout;

  struct Arguments {
    ElementBias const* ptr_row = nullptr;
    ElementBias null_default = ElementBias(0);
    cute::Stride<cute::_1,cute::_0,cute::_0> dRow = {};
  };

private:
  cutlass::NumericConverter<ElementCompute, ElementBias> _bias_converter;
  cutlass::HostTensor<ElementBias, LayoutTagVector> _bias;
  // Extent of the vector; zero until the problem-size constructor runs.
  int _M = 0;
  TestBedImpl impl_;

public:
  HostColBroadcast(){}

  /// Allocates the length-M vector, fills it through detail::initialize_tensor
  /// (Uniform distribution, seed offset by 2023) and copies it to the device.
  template<typename ProblemShapeType>
  HostColBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(check_relative_equality), impl_(impl) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    _M = cute::get<0>(problem_shape_MNKL);
    _bias.resize(cutlass::Coord<1>(_M));

    EXPECT_TRUE(
      detail::initialize_tensor(
        _bias.host_view(), cutlass::Distribution::Uniform,
        impl_.collective_mma_inputs.seed + 2023
      )
    );
    _bias.sync_device();
  }

  /// Returns the vector entry for row (m + m_b), converted to ElementCompute.
  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc) {
    auto TensorBias = cute::make_tensor(_bias.host_data(),
      cute::make_layout(cute::make_shape(_M, cute::_1{})));

    // The trailing mode has extent 1, so its only valid coordinate is 0.
    return _bias_converter(TensorBias(m + m_b, 0));
  }

  /// Nothing to verify; dumps the vector so failures elsewhere can be diagnosed.
  bool compare_reference(std::stringstream& error_ss) {
    error_ss
      << "PerRowBias = \n" << _bias.host_view() << "\n\n";
    return true;
  }

  /// Only the device pointer is set; the other fields keep their defaults.
  Arguments get_arguments() {
    return {_bias.device_data()};
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Aux Load
///
/// Leaf node that reads one element of an auxiliary tensor per output coordinate.
/// The tensor is allocated as (M * L) x N, filled through
/// detail::initialize_tensor and copied to the device at construction.
/// When `isC` is true the tensor is exposed through get_tensor_C_ptr() and the
/// node emits empty arguments; otherwise it emits pointer, default value and stride.
/// `void` for ElementAuxLoad_ / LayoutTagAux_ selects ElementC / LayoutTagC.
template <
  typename Gemm,
  bool isC=false,
  typename ElementAuxLoad_=void,
  typename LayoutTagAux_=void
>
class HostAuxLoad: public HostEVTNodeBase<Gemm> {
public:
  using ElementAuxLoad = std::conditional_t<std::is_void_v<ElementAuxLoad_>,
                                            typename HostEVTNodeBase<Gemm>::ElementC,
                                            ElementAuxLoad_>;
  using LayoutTagAux = std::conditional_t<std::is_void_v<LayoutTagAux_>,
                                          typename HostEVTNodeBase<Gemm>::LayoutTagC,
                                          LayoutTagAux_>;
  using Base = HostEVTNodeBase<Gemm>;
  using TestBedImpl = typename Base::TestBedImpl;
  using ElementCompute = typename Base::ElementCompute;
  using StrideAux = cutlass::gemm::TagToStrideC_t<LayoutTagAux>;

  struct Arguments_Aux {
    ElementAuxLoad const *ptr_aux = nullptr;
    ElementAuxLoad null_default = ElementAuxLoad(0);
    StrideAux dAux = {};
  };
  struct Arguments_C {};
  using Arguments = cute::conditional_t<isC, Arguments_C, Arguments_Aux>;

private:
  cutlass::NumericConverter<ElementCompute, ElementAuxLoad> _aux_load_converter;
  cutlass::HostTensor<ElementAuxLoad, LayoutTagAux> _tensor_aux_load;
  // Problem extents used to describe the aux tensor in visit().
  int _M = 0;
  int _N = 0;
  int _L = 0;
  TestBedImpl impl_;
  StrideAux _stride_aux;

public:
  HostAuxLoad(){}

  template<typename ProblemShapeType>
  HostAuxLoad(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(check_relative_equality), impl_(impl) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    // Store the extents in the members. A structured binding that reuses the names
    // _M/_N/_L would declare new locals hiding the members and leave the members
    // unset for visit().
    _M = cute::get<0>(problem_shape_MNKL);
    _N = cute::get<1>(problem_shape_MNKL);
    _L = cute::get<3>(problem_shape_MNKL);

    // The L batches are stacked along the first dimension of the host tensor.
    auto aux_coord = cutlass::make_Coord(_M * _L, _N);
    _tensor_aux_load.resize(
      aux_coord,
      cutlass::layout::Affine2Layout_Factory<LayoutTagAux>::layout_factory(
        aux_coord, typename LayoutTagAux::Stride()
      )
    );
    EXPECT_TRUE(
      detail::initialize_tensor(
        _tensor_aux_load.host_view(),
        cutlass::Distribution::Uniform,
        impl_.collective_mma_inputs.seed + 2023
      )
    );
    _tensor_aux_load.sync_device();
    _stride_aux = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(_M, _N, _L));
  }

  /// Returns the aux element at (m + m_b, n + n_b, l), converted to ElementCompute.
  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc) {
    auto TensorAuxLoad = cute::make_tensor(_tensor_aux_load.host_data(),
      cute::make_layout(cute::make_shape(_M, _N, _L), _stride_aux));
    return _aux_load_converter(TensorAuxLoad(m + m_b, n + n_b, l));
  }

  /// Nothing to verify; dumps the aux tensor unless it is acting as C.
  bool compare_reference(std::stringstream& error_ss) {
    if constexpr (!isC) {
      error_ss
        << "AuxLoad = \n" << _tensor_aux_load.host_view()<< "\n\n";
    }
    return true;
  }

  /// Device pointer of the tensor when it acts as C, nullptr otherwise.
  void* get_tensor_C_ptr() {
    if constexpr (isC) {
      return static_cast<void*>(_tensor_aux_load.device_data());
    } else {
      return nullptr;
    }
  }

  Arguments get_arguments() {
    if constexpr (isC)
      return {};
    else
      return {_tensor_aux_load.device_data(), ElementAuxLoad(0), _stride_aux};
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Compute
/// Terminal case: with a single candidate left it is returned as-is, null or not.
template<typename T>
T* findNonNullPtr(T* first_ptr) {
  return first_ptr;
}

/// Returns the first non-null pointer among the arguments, scanning left to right.
/// If every pointer is null, the last one (a null pointer) is returned.
template <typename T, typename... Args>
T* findNonNullPtr(T* first_ptr, Args... args) {
  if (first_ptr == nullptr) {
    // Head is null: keep looking through the remaining candidates.
    return findNonNullPtr(args...);
  }
  return first_ptr;
}
/// Host reference for a compute node: applies ComputeOp_<ElementCompute> to the
/// results of its children. The children themselves are owned and evaluated by the
/// enclosing visitor, which passes their results in as `frg_inputs`.
template <
  typename Gemm,
  template <class> class ComputeOp_
>
class HostCompute: public HostEVTNodeBase<Gemm> {
public:
  using Base = HostEVTNodeBase<Gemm>;
  using ElementCompute = typename Base::ElementCompute;
  using ComputeOp = ComputeOp_<ElementCompute>;

  /// The operation takes no arguments; `op` is an empty placeholder.
  struct Arguments {
    struct OpArgs {} op;
  };

private:
  ComputeOp _compute_op;

public:
  HostCompute(){}

  /// `problem_size` and `impl` are unused; they keep the constructor signature
  /// uniform across nodes.
  template <typename ProblemShapeType, typename TestBedImpl>
  HostCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(check_relative_equality) { }

  /// Returns ComputeOp applied to all child results; coordinates and `acc` are
  /// not used by this node.
  template <class ElementAccumulator, typename... Args>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc, Args... frg_inputs) {
    return _compute_op(frg_inputs...);
  }

  Arguments get_arguments(){
    return Arguments{};
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Unary Compute
///
/// Compute node that owns its single child: visit() evaluates the child and
/// applies ComputeOp_<ElementCompute> to the child's result.
template <
  typename Gemm,
  template <class> class ComputeOp_,
  typename Child0
>
class HostUnaryCompute: public HostEVTNodeBase<Gemm> {
public:
  using Base = HostEVTNodeBase<Gemm>;
  using ElementCompute = typename Base::ElementCompute;
  using ComputeOp = ComputeOp_<ElementCompute>;

  /// Child arguments first, followed by the (empty) operation arguments.
  struct Arguments {
    typename Child0::Arguments child_0_args;
    struct OpArgs {} op;
  };

private:
  ComputeOp _compute_op;
  Child0 _child;

public:
  HostUnaryCompute(){}

  /// Forwards all three constructor arguments to the child node.
  template <typename ProblemShapeType, typename TestBedImpl>
  HostUnaryCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(check_relative_equality),
      _child(problem_size, impl, check_relative_equality) { }

  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc) {
    // Evaluate the child first; its result is the operation's only operand.
    ElementCompute operand = _child.visit(m, n, l, m_b, n_b, acc);
    return _compute_op(operand);
  }

  Arguments get_arguments(){
    Arguments args{_child.get_arguments(), {}};
    return args;
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Aux Store
///
/// Node that records its child's result into a host reference tensor and passes
/// the value through unchanged. compare_reference() copies the device tensor
/// back and compares it against the recorded reference.
/// When `isD` is true the device tensor is exposed through get_tensor_D_ptr() and
/// the node emits empty arguments; otherwise it emits pointer and stride.
/// `void` for ElementAuxStore_ / LayoutTagAux_ selects ElementD / LayoutTagD.
template <
  typename Gemm,
  bool isD=false,
  class ElementAuxStore_=void,
  typename LayoutTagAux_=void
>
class HostAuxStore: public HostEVTNodeBase<Gemm> {
public:
  using ElementAuxStore = std::conditional_t<std::is_void_v<ElementAuxStore_>,
                                             typename HostEVTNodeBase<Gemm>::ElementD,
                                             ElementAuxStore_>;
  using LayoutTagAux = std::conditional_t<std::is_void_v<LayoutTagAux_>,
                                          typename HostEVTNodeBase<Gemm>::LayoutTagD,
                                          LayoutTagAux_>;
  using Base = HostEVTNodeBase<Gemm>;
  using TestBedImpl = typename Base::TestBedImpl;
  using ElementCompute = typename Base::ElementCompute;
  using StrideAux = cutlass::gemm::TagToStrideC_t<LayoutTagAux>;

  struct Arguments_Aux {
    struct OpArgs {
      ElementAuxStore* ptr_aux = nullptr;
      StrideAux dAux = {};
    } op;
  };
  struct Arguments_D {};
  using Arguments = cute::conditional_t<isD, Arguments_D, Arguments_Aux>;

private:
  cutlass::NumericConverter<ElementAuxStore, ElementCompute> destination_converter;
  // Written by the device; read back in compare_reference().
  cutlass::HostTensor<ElementAuxStore, LayoutTagAux> _tensor_aux_store;
  // Written on the host by visit().
  cutlass::HostTensor<ElementAuxStore, LayoutTagAux> _reference_aux_store;
  // Problem extents used to describe the reference tensor in visit().
  int _M = 0;
  int _N = 0;
  int _L = 0;
  TestBedImpl impl_;
  StrideAux _stride_aux;

public:
  HostAuxStore(){}

  template <typename ProblemShapeType>
  HostAuxStore(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
    Base(check_relative_equality),
    impl_(impl) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    // Store the extents in the members. A structured binding that reuses the names
    // _M/_N/_L would declare new locals hiding the members and leave the members
    // unset for visit().
    _M = cute::get<0>(problem_shape_MNKL);
    _N = cute::get<1>(problem_shape_MNKL);
    _L = cute::get<3>(problem_shape_MNKL);

    // The L batches are stacked along the first dimension of the host tensors.
    auto aux_coord = cutlass::make_Coord(_M * _L, _N);
    _tensor_aux_store.resize(
      aux_coord,
      cutlass::layout::Affine2Layout_Factory<LayoutTagAux>::layout_factory(
        aux_coord, typename LayoutTagAux::Stride()
      )
    );
    _reference_aux_store.resize(
      aux_coord,
      cutlass::layout::Affine2Layout_Factory<LayoutTagAux>::layout_factory(
        aux_coord, typename LayoutTagAux::Stride()
      )
    );
    _tensor_aux_store.sync_device();
    _stride_aux = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(_M, _N, _L));
  }

  /// Stores the converted child result at (m + m_b, n + n_b, l) of the reference
  /// tensor and returns the unconverted child result.
  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc, ElementCompute child_0_result) {
    auto TensorAuxStore = cute::make_tensor(static_cast<ElementAuxStore*>(_reference_aux_store.host_data()),
      cute::make_layout(cute::make_shape(_M, _N, _L), _stride_aux));
    TensorAuxStore(m + m_b, n + n_b, l) = destination_converter(child_0_result);
    return child_0_result;
  }

  /// Compares the device result with the host reference; on mismatch both tensors
  /// are written to `error_ss`.
  bool compare_reference(std::stringstream& error_ss) {
    // Verify the store node
    _tensor_aux_store.sync_host();
    bool equal = this->equality_check(_reference_aux_store.host_view(), _tensor_aux_store.host_view());
    if (!equal) {
      error_ss
        << "\n\nReference =\n" << _reference_aux_store.host_view()
        << "\n\nComputed =\n" << _tensor_aux_store.host_view() << "\n\n";
    }
    return equal;
  }

  /// Device pointer of the store tensor when it acts as D, nullptr otherwise.
  void* get_tensor_D_ptr() {
    if constexpr (isD)
      return static_cast<void*>(_tensor_aux_store.device_data());
    else
      return nullptr;
  }

  Arguments get_arguments() {
    if constexpr (isD) {
      return {};
    } else {
      return {_tensor_aux_store.device_data(), _stride_aux};
    }
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Row Reduce
///
/// Node that folds its child's results into a length-N vector (one entry per
/// column index) with ReduceFn and passes each value through unchanged. The
/// running reduction is kept in ElementCompute; compare_reference() converts it
/// to ElementReduce and compares it with the vector produced by the device.
/// NOTE(review): the reduce buffer is not explicitly seeded here, so the starting
/// value is whatever HostTensor::resize leaves - confirm it equals the identity of
/// the ReduceFn in use.
template <
  typename Gemm,
  template <class> class ReduceFn,
  typename ElementReduce
>
class HostRowReduce: public HostEVTNodeBase<Gemm> {
public:
  using Base = HostEVTNodeBase<Gemm>;
  using TestBedImpl = typename Base::TestBedImpl;
  using ElementCompute = typename Base::ElementCompute;
  using ElementOutput = typename Base::ElementD;
  using LayoutTagVector = cutlass::layout::PackedVectorLayout;

  struct Arguments {
    struct OpArgs {
      ElementReduce* ptr_row = nullptr;
      ElementCompute reduce_identity = 0;
      cute::Stride<cute::_0, cute::_1, cute::_0> dRow = {};
    } op;
  };

private:
  cutlass::NumericConverter<ElementReduce, ElementCompute> destination_converter;
  // Written by the device; read back in compare_reference().
  cutlass::HostTensor<ElementReduce, LayoutTagVector> _tensor_row_reduce;
  // Host-side running reduction, kept in ElementCompute.
  cutlass::HostTensor<ElementCompute, LayoutTagVector> _reduce_buffer;
  // _reduce_buffer converted to ElementReduce for the comparison.
  cutlass::HostTensor<ElementReduce, LayoutTagVector> _reference_row_reduce;
  // Extent of the vectors; zero until the problem-size constructor runs.
  int _N = 0;
  TestBedImpl impl_;
  ReduceFn<ElementCompute> reduce_fn;

public:
  HostRowReduce(){}

  template <typename ProblemShapeType>
  HostRowReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
    Base(check_relative_equality),
    impl_(impl) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    _N = cute::get<1>(problem_shape_MNKL);
    _tensor_row_reduce.resize(cutlass::Coord<1>(_N));
    _reference_row_reduce.resize(cutlass::Coord<1>(_N));
    _reduce_buffer.resize(cutlass::Coord<1>(_N));
    _tensor_row_reduce.sync_device();
  }

  /// Folds `child_0_result` into entry (n + n_b) of the buffer and returns it.
  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc, ElementCompute child_0_result) {
    auto TensorRowReduce = cute::make_tensor(_reduce_buffer.host_data(),
      cute::make_layout(cute::make_shape(cute::_1{}, _N)));
    // The leading mode has extent 1, so its only valid coordinate is 0.
    TensorRowReduce(0, n + n_b) = reduce_fn(TensorRowReduce(0, n + n_b), child_0_result);
    return child_0_result;
  }

  /// Compares the device vector with the host reduction; on mismatch both are
  /// written to `error_ss`.
  bool compare_reference(std::stringstream& error_ss) {
    // Verify the store node
    _tensor_row_reduce.sync_host();

    auto TensorRowReduce = cute::make_tensor(_reference_row_reduce.host_data(),
      cute::make_layout(cute::make_shape(cute::_1{}, _N)));

    auto TensorReduceBuffer = cute::make_tensor(_reduce_buffer.host_data(),
      cute::make_layout(cute::make_shape(cute::_1{}, _N)));

    // Filling the reference tensor with the reduce buffer
    for (int n = 0; n < _N; n ++) {
      TensorRowReduce(0, n) = destination_converter(TensorReduceBuffer(0, n));
    }

    bool equal = this->equality_check(_reference_row_reduce.host_view(), _tensor_row_reduce.host_view());
    if (!equal) {
      error_ss
        << "\n\nRow Reduce Reference =\n" << _reference_row_reduce.host_view()
        << "\n\nRow Reduce Computed =\n" << _tensor_row_reduce.host_view() << "\n\n";
    }
    return equal;
  }

  /// Only the device pointer is set; the other fields keep their defaults.
  Arguments get_arguments() {
    return {_tensor_row_reduce.device_data()};
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Column Reduce
///
/// Node that folds its child's results into a length-M vector (one entry per row
/// index) with ReduceFn and passes each value through unchanged. The running
/// reduction is kept in ElementCompute; compare_reference() converts it to
/// ElementReduce and compares it with the vector produced by the device.
/// NOTE(review): the reduce buffer is not explicitly seeded here, so the starting
/// value is whatever HostTensor::resize leaves - confirm it equals the identity of
/// the ReduceFn in use.
template <
  typename Gemm,
  template <class> class ReduceFn,
  typename ElementReduce
>
class HostColumnReduce: public HostEVTNodeBase<Gemm> {
public:
  using Base = HostEVTNodeBase<Gemm>;
  using TestBedImpl = typename Base::TestBedImpl;
  using ElementCompute = typename Base::ElementCompute;
  using ElementOutput = typename Base::ElementD;
  using LayoutTagVector = cutlass::layout::PackedVectorLayout;

  struct Arguments {
    struct OpArgs {
      ElementReduce* ptr_col = nullptr;
      ElementCompute reduce_identity = 0;
      cute::Stride<cute::_1, cute::_0, cute::_0> dRow = {};
    } op;
  };

private:
  cutlass::NumericConverter<ElementReduce, ElementCompute> destination_converter;
  // Written by the device; read back in compare_reference().
  cutlass::HostTensor<ElementReduce, LayoutTagVector> _tensor_column_reduce;
  // Host-side running reduction, kept in ElementCompute.
  cutlass::HostTensor<ElementCompute, LayoutTagVector> _reduce_buffer;
  // _reduce_buffer converted to ElementReduce for the comparison.
  cutlass::HostTensor<ElementReduce, LayoutTagVector> _reference_column_reduce;
  // Extent of the vectors; zero until the problem-size constructor runs.
  int _M = 0;
  TestBedImpl impl_;
  ReduceFn<ElementCompute> reduce_fn;

public:
  HostColumnReduce(){}

  template <typename ProblemShapeType>
  HostColumnReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
    Base(check_relative_equality),
    impl_(impl) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    _M = cute::get<0>(problem_shape_MNKL);
    _tensor_column_reduce.resize(cutlass::Coord<1>(_M));
    _reference_column_reduce.resize(cutlass::Coord<1>(_M));
    _reduce_buffer.resize(cutlass::Coord<1>(_M));
    _tensor_column_reduce.sync_device();
  }

  /// Folds `child_0_result` into entry (m + m_b) of the buffer and returns it.
  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc, ElementCompute child_0_result) {
    auto TensorColReduce = cute::make_tensor(_reduce_buffer.host_data(),
      cute::make_layout(cute::make_shape(_M, cute::_1{})));
    // The trailing mode has extent 1, so its only valid coordinate is 0.
    TensorColReduce(m + m_b, 0) = reduce_fn(TensorColReduce(m + m_b, 0), child_0_result);
    return child_0_result;
  }

  /// Compares the device vector with the host reduction; on mismatch both are
  /// written to `error_ss`.
  bool compare_reference(std::stringstream& error_ss) {
    // Verify the store node
    _tensor_column_reduce.sync_host();

    auto TensorColReduce = cute::make_tensor(_reference_column_reduce.host_data(),
      cute::make_layout(cute::make_shape(_M, cute::_1{})));

    auto TensorReduceBuffer = cute::make_tensor(_reduce_buffer.host_data(),
      cute::make_layout(cute::make_shape(_M, cute::_1{})));

    // Filling the reference tensor with the reduce buffer
    for (int m = 0; m < _M; m ++) {
      TensorColReduce(m, 0) = destination_converter(TensorReduceBuffer(m, 0));
    }

    bool equal = this->equality_check(_reference_column_reduce.host_view(), _tensor_column_reduce.host_view());
    if (!equal) {
      error_ss
        << "\n\nColumn Reduce Reference =\n" << _reference_column_reduce.host_view()
        << "\n\nColumn Reduce Computed =\n" << _tensor_column_reduce.host_view() << "\n\n";
    }
    return equal;
  }

  /// Only the device pointer is set; the other fields keep their defaults.
  Arguments get_arguments() {
    return {_tensor_column_reduce.device_data()};
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// EVT - Scalar Reduce
///
/// Node that folds every child result into a single value with ReduceFn and
/// passes each value through unchanged. compare_reference() converts the host
/// reduction to ElementReduce and compares it with the value produced by the device.
template <
  typename Gemm,
  template <class> class ReduceFn,
  typename ElementReduce
>
class HostScalarReduce: public HostEVTNodeBase<Gemm> {
public:
  using Base = HostEVTNodeBase<Gemm>;
  using TestBedImpl = typename Base::TestBedImpl;
  using ElementCompute = typename Base::ElementCompute;
  using ElementOutput = typename Base::ElementD;
  using LayoutTagVector = cutlass::layout::PackedVectorLayout;

  struct Arguments {
    struct OpArgs {
      ElementReduce* ptr_scalar = nullptr;
      ElementCompute reduce_identity = 0;
      cute::Stride<cute::_0, cute::_0, cute::_0> dScalar = {};
    } op;
  };

private:
  cutlass::NumericConverter<ElementReduce, ElementCompute> destination_converter;
  // Single-element tensors: device result, host running reduction, host reference.
  cutlass::HostTensor<ElementReduce, LayoutTagVector> _tensor_scalar_reduce;
  cutlass::HostTensor<ElementCompute, LayoutTagVector> _reduce_buffer;
  cutlass::HostTensor<ElementReduce, LayoutTagVector> _reference_scalar_reduce;
  ReduceFn<ElementCompute> reduce_fn;
  TestBedImpl impl_;

public:
  HostScalarReduce(){}

  /// `problem_size` is not needed: all three tensors hold exactly one element.
  template <typename ProblemShapeType>
  HostScalarReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false):
    Base(check_relative_equality),
    impl_(impl) {
    _tensor_scalar_reduce.resize(cutlass::Coord<1>(1));
    _reference_scalar_reduce.resize(cutlass::Coord<1>(1));
    _reduce_buffer.resize(cutlass::Coord<1>(1));
    _tensor_scalar_reduce.sync_device();
  }

  /// Folds `child_0_result` into the running reduction and returns it unchanged.
  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc, ElementCompute child_0_result) {
    auto Running = cute::make_tensor(_reduce_buffer.host_data(),
      cute::make_layout(cute::make_shape(cute::_1{})));
    Running(0) = reduce_fn(Running(0), child_0_result);
    return child_0_result;
  }

  /// Compares the device value with the host reduction; on mismatch both are
  /// written to `error_ss`.
  bool compare_reference(std::stringstream& error_ss) {
    // Bring the device result back to the host before comparing.
    _tensor_scalar_reduce.sync_host();

    auto Running = cute::make_tensor(_reduce_buffer.host_data(),
      cute::make_layout(cute::make_shape(cute::_1{})));
    auto Reference = cute::make_tensor(_reference_scalar_reduce.host_data(),
      cute::make_layout(cute::make_shape(cute::_1{})));

    // The reference is the running reduction converted to the output type.
    Reference(0) = destination_converter(Running(0));

    bool const matches = this->equality_check(
      _reference_scalar_reduce.host_view(), _tensor_scalar_reduce.host_view());
    if (matches) {
      return true;
    }
    error_ss
      << "\n\nScalar Reduce Reference =\n" << _reference_scalar_reduce.host_view()
      << "\n\nScalar Reduce Computed =\n" << _tensor_scalar_reduce.host_view() << "\n\n";
    return false;
  }

  /// Only the device pointer is set; the other fields keep their defaults.
  Arguments get_arguments() {
    return {_tensor_scalar_reduce.device_data()};
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Host EVT wrapper
/// ArgumentPack stores one argument object per op as nested plain members
/// (`arg`, then `rest_args.arg`, and so on). It is used to model the alignment
/// when num ops <= 4.
template <typename... Ops>
struct ArgumentPack;

/// Terminal case: a pack holding exactly one argument.
template <typename Last>
struct ArgumentPack<Last> {
  Last arg;

  ArgumentPack(Last value)
    : arg(value) {}
};

/// Recursive case: the leading argument followed by a pack of the remaining ones.
template <typename Head, typename... Tail>
struct ArgumentPack<Head, Tail...> {
  Head arg;
  ArgumentPack<Tail...> rest_args;

  ArgumentPack(Head head, Tail... tail)
    : arg(head), rest_args(tail...) {}
};
/// Base class for Host Visitor
///
/// Owns one instance of every op in `Ops...` and fans the node hooks out to them:
/// the reference check passes only if every op's check passes, the C/D pointer is
/// the first non-null pointer reported by an op, and the arguments of all ops are
/// bundled into a single `Arguments` object.
template <typename Gemm, class... Ops>
struct HostVisitorBase: public HostEVTNodeBase<Gemm> {
public:
  using Base = HostEVTNodeBase<Gemm>;
  using ElementCompute = typename Base::ElementCompute;

  using Arguments_struct = ArgumentPack<typename Ops::Arguments...>;
  using Arguments_tuple = cute::tuple<typename Ops::Arguments...>;

  constexpr static int Rm1 = sizeof...(Ops);
  constexpr static bool cond = Rm1 > 4;
  // More than four ops: cute::tuple; otherwise the nested ArgumentPack.
  using Arguments = cute::conditional_t<cond, Arguments_tuple, Arguments_struct>;

  std::tuple<Ops...> ops;

  HostVisitorBase(){}

  /// Builds every op from (problem_size, impl, check_relative_equality). The
  /// default-constructed tuple passed to tapply only supplies the op types.
  template<typename ProblemShapeType, typename TestBedImpl>
  HostVisitorBase(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(check_relative_equality),
      ops(test::gemm::device::tapply(std::tuple<Ops...>{},
        [&] (auto&& prototype) {
          using Op = cute::remove_cvref_t<decltype(prototype)>;
          return Op(problem_size, impl, check_relative_equality);
        },
        [] (auto&&... built_ops) {
          return std::make_tuple(built_ops...);
        },
        cute::make_seq<Rm1>{}
      )) { }

  /// Runs compare_reference() on every op and ANDs the outcomes.
  bool compare_reference(std::stringstream& error_ss) {
    auto check_op = [&] (auto& op) {
      return op.compare_reference(error_ss);
    };
    auto all_passed = [&] (auto&&... outcomes) {
      return arrayAnd(outcomes...);
    };
    return cute::detail::tapply(ops, check_op, all_passed, cute::make_seq<Rm1>{});
  }

  /// First non-null C pointer reported by an op, or a null pointer if none.
  void* get_tensor_C_ptr() {
    auto query_op = [&] (auto& op) {
      return op.get_tensor_C_ptr();
    };
    auto pick_first = [&] (auto&&... candidates) {
      return findNonNullPtr(candidates...);
    };
    return cute::detail::tapply(ops, query_op, pick_first, cute::make_seq<Rm1>{});
  }

  /// First non-null D pointer reported by an op, or a null pointer if none.
  void* get_tensor_D_ptr() {
    auto query_op = [&] (auto& op) {
      return op.get_tensor_D_ptr();
    };
    auto pick_first = [&] (auto&&... candidates) {
      return findNonNullPtr(candidates...);
    };
    return cute::detail::tapply(ops, query_op, pick_first, cute::make_seq<Rm1>{});
  }

  /// Collects the arguments of every op, in op order, into `Arguments`.
  Arguments get_arguments() {
    auto query_op = [&] (auto& op) {
      return op.get_arguments();
    };
    auto bundle = [&] (auto&&... op_args) {
      if constexpr (Rm1 > 4) {
        return cute::make_tuple(op_args...);
      } else {
        return Arguments(op_args...);
      }
    };
    return test::gemm::device::tapply(ops, query_op, bundle, cute::make_seq<Rm1>{});
  }

  /// Logical AND of one flag.
  bool arrayAnd(bool passed) {
    return passed;
  }

  /// Logical AND of all flags.
  template <typename... Args>
  bool arrayAnd(bool first_passed, Args... passed) {
    return first_passed ? arrayAnd(passed...) : false;
  }
};
/// Tree-struct visitor
///
/// The ops are stored as (ChildOps..., NodeOp). visit() evaluates every child with
/// the same coordinates and accumulator, then calls the node op with the
/// children's results appended after the accumulator.
template <class NodeOp, class... ChildOps>
struct HostTreeVisitor: public HostVisitorBase<typename NodeOp::Base::Gemm, ChildOps..., NodeOp> {
public:
  using Gemm = typename NodeOp::Base::Gemm;
  using Base = HostVisitorBase<Gemm, ChildOps..., NodeOp>;
  using ElementCompute = typename Base::ElementCompute;
  using Arguments = typename Base::Arguments;

  // Number of children, which is also the tuple index of the node op.
  constexpr static int Rm1 = sizeof...(ChildOps);

  HostTreeVisitor(){}

  template<typename ProblemShapeType, typename TestBedImpl>
  HostTreeVisitor(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(problem_size, impl, check_relative_equality) { }

  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc) {
    auto visit_child = [&] (auto& child) {
      return child.visit(m, n, l, m_b, n_b, acc);
    };
    auto visit_node = [&] (auto&&... child_results) {
      return std::get<Rm1>(this->ops).visit(m, n, l, m_b, n_b, acc, child_results...);
    };
    // make_seq<Rm1> covers the children only; the node op is applied by visit_node.
    return cute::detail::tapply(this->ops, visit_child, visit_node, cute::make_seq<Rm1>{});
  }
};
/// General Graph visitor
///
/// Evaluates the ops in index order 0 .. Rm1-1. Entry I of `EdgeTuple` lists the
/// indices of the ops whose outputs are passed to op I; the output of the last op
/// is the result of visit(). An edge therefore has to name an op that has already
/// been evaluated, otherwise the slot it reads has not been written for this visit.
template <class Gemm, class EdgeTuple, class... Ops>
struct HostTopoVisitor: public HostVisitorBase<Gemm, Ops...> {
public:
  using Base = HostVisitorBase<Gemm, Ops...>;
  using ElementCompute = typename Base::ElementCompute;
  constexpr static int Rm1 = Base::Rm1;
  using Arguments = typename Base::Arguments;

private:
  // Output of each op for the coordinate currently being visited.
  ElementCompute frg_outputs[Rm1];

public:
  HostTopoVisitor(){}

  template<typename ProblemShapeType, typename TestBedImpl>
  HostTopoVisitor(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(problem_size, impl, check_relative_equality) { }

  /// Evaluates op I, stores its output, then continues with op I + 1; the last op's
  /// output is returned.
  template<class ElementAccumulator, int I>
  ElementCompute visit_(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc) {
    // Maps an edge (a compile-time index) to the stored output of that op.
    auto fetch_input = [&] (auto&& edge) {
      constexpr int source = cute::remove_cvref_t<decltype(edge)>::value;
      return frg_outputs[source];
    };
    // Calls op I with the gathered inputs appended after the accumulator.
    auto run_op = [&] (auto const&... op_inputs) {
      ElementCompute op_result = std::get<I>(this->ops).visit(m, n, l, m_b, n_b, acc, op_inputs...);
      return op_result;
    };
    frg_outputs[I] = cute::transform_apply(cute::get<I>(EdgeTuple{}), fetch_input, run_op);

    if constexpr (I + 1 < Rm1) {
      return visit_<ElementAccumulator, I + 1>(m, n, l, m_b, n_b, acc);
    } else {
      return frg_outputs[I];
    }
  }

  template <class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc) {
    return visit_<ElementAccumulator, 0>(m, n, l, m_b, n_b, acc);
  }
};
/// SplitTree visitor
///
/// The ops are stored as (InputTree, AuxOutTrees..., OutputTree). visit()
/// evaluates the input tree once and feeds its result, in place of the
/// accumulator, to every aux-out tree and finally to the output tree, whose
/// result is returned.
template <class Gemm, class InputTree, class OutputTree, class... AuxOutTrees>
struct HostSplitTreeVisitor: public HostVisitorBase<Gemm, InputTree, AuxOutTrees..., OutputTree> {
public:
  using Base = HostVisitorBase<Gemm, InputTree, AuxOutTrees..., OutputTree>;
  using ElementCompute = typename Base::ElementCompute;
  using Arguments = typename Base::Arguments;

  // Number of aux-out trees; they occupy tuple indices 1 .. Rm2.
  constexpr static int Rm2 = sizeof...(AuxOutTrees);

private:
  // Result of the input tree for the coordinate currently being visited.
  ElementCompute frg_input;

public:
  HostSplitTreeVisitor(){}

  template<typename ProblemShapeType, typename TestBedImpl>
  HostSplitTreeVisitor(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false)
    : Base(problem_size, impl, check_relative_equality) { }

  /// Visits aux-out tree I (tuple index I + 1) and then the remaining aux-out
  /// trees. Must only be instantiated with 0 <= I < Rm2.
  template<class ElementAccumulator, int I>
  void visitAux(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator frag) {
    std::get<I+1>(this->ops).visit(m, n, l, m_b, n_b, frag);

    if constexpr (I < Rm2 - 1) {
      return visitAux<ElementAccumulator, I+1>(m, n, l, m_b, n_b, frag);
    } else {
      return;
    }
  }

  template<class ElementAccumulator>
  ElementCompute visit(
    int64_t m, int64_t n, int64_t l, int m_b, int n_b,
    ElementAccumulator acc) {
    /// Compute the input tree
    frg_input = std::get<0>(this->ops).visit(m, n, l, m_b, n_b, acc);

    /// Compute the aux out tree
    // Skipped when there are no aux-out trees: tuple index 1 would then be the
    // output tree, which would be visited a second time below.
    if constexpr (Rm2 > 0) {
      visitAux<ElementAccumulator, 0>(m, n, l, m_b, n_b, frg_input);
    }

    /// Visit the output tree
    return std::get<Rm2+1>(this->ops).visit(m, n, l, m_b, n_b, frg_input);
  }
};
/// Universal testbed for EVT
///
/// Runs one device GEMM whose epilogue arguments are produced by an EVT host
/// module (EVT::EVTModule) and checks the outcome through that module's
/// compare_reference(). Operand setup (impl_.initialize) and profiling
/// (impl_.profile) are delegated to detail::TestbedImpl.
template <class Gemm, typename EVT>
class Testbed3xEVT {
public:
  // The EVT Module to test
  using EVTModule = typename EVT::EVTModule;
  using TestBedImpl = typename detail::TestbedImpl<Gemm, cutlass::epilogue::thread::Identity, true>;
  using Kernel = typename Gemm::GemmKernel;
  using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue;
  using ElementAccumulator = typename Kernel::ElementAccumulator;
  using ElementC = typename Kernel::ElementC;
  using ElementD = typename Kernel::ElementD;
  using ProblemShapeType = typename Kernel::ProblemShape;
  using LayoutTagA = typename TestBedImpl::LayoutTagA;
  using LayoutTagB = typename TestBedImpl::LayoutTagB;
  using LayoutTagC = typename TestBedImpl::LayoutTagC;
  using LayoutTagD = typename TestBedImpl::LayoutTagD;

  //
  // Methods
  //

  /// Constructs a testbed with an explicit choice of comparison mode.
  ///
  /// \param check_relative_equality_  true selects CheckEquality::RELATIVE for
  ///        the underlying TestbedImpl, false selects CheckEquality::EXACT; the
  ///        flag is also handed to the EVT host module in run().
  /// \param init_A_, init_B_, init_C_  Distributions forwarded to TestbedImpl.
  /// \param seed_                      Seed forwarded to TestbedImpl.
  Testbed3xEVT(
    bool check_relative_equality_,
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = TestBedImpl::kDefaultSeed
  ) :
    impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorBeta::ENABLED,
          init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_),
    check_relative_equality(check_relative_equality_) { }

  /// Constructs a testbed that uses exact comparison.
  Testbed3xEVT(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = TestBedImpl::kDefaultSeed
  ) :
    impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorBeta::ENABLED,
          init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_),
    check_relative_equality(false) { }

  /// Constructs a testbed that uses exact comparison and forwards explicit
  /// stride factors for A, B, C and D to TestbedImpl.
  Testbed3xEVT(
    typename LayoutTagA::Stride stride_factor_A_,
    typename LayoutTagB::Stride stride_factor_B_,
    typename LayoutTagC::Stride stride_factor_C_,
    typename LayoutTagD::Stride stride_factor_D_,
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = TestBedImpl::kDefaultSeed
  ) :
    impl_(stride_factor_A_, stride_factor_B_, stride_factor_C_, stride_factor_D_,
          CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorBeta::ENABLED,
          init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_),
    check_relative_equality(false) { }

  /// Initializes data structures
  void initialize(ProblemShapeType problem_size) {
    //
    // Allocate the GEMM workspace for A/B tensor
    //
    impl_.initialize(problem_size);
  }

  // Detail Implementation
  TestBedImpl impl_;

  // Whether to use relative equality checks
  bool check_relative_equality;

  /// Computes the host reference and compares it with the device result.
  ///
  /// The mainloop is evaluated on the host in kBlockM x kBlockN blocks via
  /// gett_mainloop; for every in-bounds element of a block the accumulator is
  /// handed to host_reference.visit(). Afterwards
  /// host_reference.compare_reference() decides pass/fail. On failure the
  /// operands and the comparison log are written to a text file whose name
  /// encodes the problem size and the kernel tile shape.
  ///
  /// \param problem_size    Problem shape; a missing batch extent is filled in
  ///                        as 1.
  /// \param host_reference  EVT host module. When OpenMP is enabled its visit()
  ///                        is called from several threads on this one object.
  /// \return Result of host_reference.compare_reference().
  bool verify(ProblemShapeType problem_size, EVTModule& host_reference) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    auto M = cute::get<0>(problem_shape_MNKL);
    auto N = cute::get<1>(problem_shape_MNKL);
    auto K = cute::get<2>(problem_shape_MNKL);
    auto L = cute::get<3>(problem_shape_MNKL);

    auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(),
      cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a));
    auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(),
      cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b));
    auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d);

    cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};

    /// Reference Kernel
    static int constexpr kBlockM = 64;
    static int constexpr kBlockN = 64;

#if defined(_OPENMP)
#pragma omp parallel for collapse(3)
#endif
    for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) {
      for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) {
        for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) {
          ElementAccumulator acc[kBlockM][kBlockN];
          gett_mainloop(mainloop_params, m, n, l, acc);
          /// Epilogue EVT
          for (int n_b = 0; n_b < kBlockN; ++n_b) {
            for (int m_b = 0; m_b < kBlockM; ++m_b) {
              // Blocks may overhang the problem; only in-bounds elements are visited.
              if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) {
                host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]);
              }
            }
          }
        }
      }
    }

    std::stringstream error_ss;
    bool passed = host_reference.compare_reference(error_ss);
    if (!passed) {
      std::stringstream fname;
      fname << "error_Gemm_device_"
        << M << "x" << N << "x" << K << "x" << L << "_"
        << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_"
        << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_"
        << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt";

      std::ofstream file(fname.str());
      file
        << "problem: " << ' ' << M << "x" << N << "x" << K
        << ", Batch count = " << L << "\n\n";

      file
        << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view()
        << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view()
        << "\nC =\n" << impl_.collective_epilogue.tensor_C.host_view() << "\n\n";

      file << error_ss.str();
    }

    return passed;
  }

  /// Runs (or profiles) the GEMM for one problem size and verifies the result.
  ///
  /// \param problem_size  Problem shape to run.
  /// \param profiling     When true, impl_.profile() is run instead of a single
  ///                      launch followed by verification, and the SM count is
  ///                      not clamped to impl_.MaxSmCount.
  /// \param iterations    Forwarded to impl_.profile().
  /// \param splits        Scheduler argument, used only when the kernel's tile
  ///                      scheduler tag is StreamKScheduler.
  /// \return false if the device is insufficient, if initialization, or the
  ///         device synchronization fails, or if verification fails; true if
  ///         the kernel reports the problem as unsupported (treated as a skip)
  ///         or verification passes; the result of impl_.profile() when
  ///         profiling.
  bool run(
    ProblemShapeType problem_size,
    bool profiling = false,
    int iterations = 20,
    int splits = 1) {
    // Fail test if insufficient CUDA device
    if (!impl_.sufficient()) {
      std::cout << "Test failed due to insufficient CUDA device." << std::endl;
      return false;
    }

    //
    // Initialize the Gemm operator
    //
    typename Gemm::Arguments arguments;
    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = 0;
    if (!profiling) {
      impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
      hw_info.sm_count = impl_.sm_count;
    }
    else {
      impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
      hw_info.sm_count = impl_.sm_count;
    }

    typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args;
    if constexpr (cute::is_same_v<typename Gemm::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
      scheduler_args = { splits };
    }

    /// Initializes data structures
    /// A/B/C/D Tensor
    initialize(problem_size);

    /// Initialize the epilogue arguments
    EVTModule host_reference(problem_size, impl_, check_relative_equality);

    arguments = typename Gemm::Arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      {
        impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a,
        impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b
      },
      { // Epilogue arguments
        {}, // thread
        static_cast<ElementC*>(host_reference.get_tensor_C_ptr()),
        impl_.collective_epilogue.stride_c,
        static_cast<ElementD*>(host_reference.get_tensor_D_ptr()),
        impl_.collective_epilogue.stride_d
      }, // Epilogue arguments end
      hw_info,
      scheduler_args
    };

    // Filling in the thread arguments
    // NOTE(review): this byte copy relies on epilogue_args.arg being no larger
    // than, and layout-compatible with, arguments.epilogue.thread -- confirm.
    typename EVTModule::Arguments epilogue_args = host_reference.get_arguments();
    std::memcpy(&arguments.epilogue.thread, &epilogue_args.arg, sizeof(epilogue_args.arg));

    Gemm gemm_op;

    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    cutlass::Status status = gemm_op.can_implement(arguments);

    if (status != cutlass::Status::kSuccess) {
      // An unsupported configuration is reported but not counted as a failure.
      cudaError_t error = cudaGetLastError();
      std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
      return true;
    }

    //
    // Run the GEMM
    //
    if (profiling) {
      return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace);
    }
    else {
      status = gemm_op.initialize(arguments, workspace.get());
      if (status != cutlass::Status::kSuccess) {
        // Do not launch on top of a failed initialization; previously this
        // status was overwritten by the result of run() and never examined.
        EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
        return false;
      }

      status = gemm_op.run();

      cudaError_t result = cudaDeviceSynchronize();
      if (result != cudaSuccess) {
        EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync.";
        return false;
      }
    }

    EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);

    //
    // Verify
    //
    bool passed = this->verify(problem_size, host_reference);
    if (!passed) {
      std::cout << "Error : Failed \n";
    }

    return passed;
  }
};
/// Runs the EVT testbed over a small sweep of problem sizes.
///
/// M, N and K extents are derived from the larger of the kernel's A/B
/// alignments, with one extra 768 extent for M and N when the kernel schedule
/// is KernelTmaWarpSpecializedPingpong. If the problem shape has rank 4, one
/// additional batched problem (L = 3) is run after the sweep.
///
/// \param check_relative_equality  Forwarded to the Testbed3xEVT constructor.
/// \return false as soon as any run fails, true otherwise.
template <typename Gemm, typename EVT>
bool TestAllEVT(bool check_relative_equality=false) {
  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
  using DispatchPolicy = typename Gemm::GemmKernel::DispatchPolicy;
  using TileShape = typename Gemm::GemmKernel::TileShape;

  constexpr int Stages = DispatchPolicy::Stages;
  constexpr int TileShapeK = cute::size<2>(TileShape{});
  // True when the problem shape carries an explicit batch extent.
  constexpr bool kHasBatchMode = (cute::rank(ProblemShapeType{}) == 4);

  int alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB);

  std::vector<int> m_extents{alignment, 512 - 3 * alignment};
  std::vector<int> n_extents{alignment, 512 - 2 * alignment};
  std::vector<int> k_extents{alignment, TileShapeK * (Stages + 1) - alignment};

  if constexpr (cute::is_same_v<typename DispatchPolicy::Schedule,
                                cutlass::gemm::KernelTmaWarpSpecializedPingpong>) {
    m_extents.push_back(768);
    n_extents.push_back(768);
  }

  Testbed3xEVT<Gemm, EVT> testbed(check_relative_equality);

  for (int m : m_extents) {
    for (int n : n_extents) {
      for (int k : k_extents) {
        ProblemShapeType problem_size;
        if constexpr (kHasBatchMode) {
          problem_size = ProblemShapeType{m, n, k, /* l */ 1};
        }
        else {
          problem_size = ProblemShapeType{m, n, k};
        }

        // Stop at the first failing problem size.
        if (!testbed.run(problem_size)) {
          return false;
        }
      }
    }
  }

  // if we do support batched GEMM, just run one test on it to save on test time
  if constexpr (kHasBatchMode) {
    auto batched_problem = ProblemShapeType{256 + alignment, 256 + alignment, 160 + alignment, /* l */ 3};
    if (!testbed.run(batched_problem)) {
      return false;
    }
  }

  // Reaching this point means every run above succeeded.
  return true;
}
} // namespace device
} // namespace gemm
} // namespace test
/////////////////////////////////////////////////////////////////////////////////////////////////