delicate generic function generic_size

This commit is contained in:
Alfredo Correa 2022-08-26 15:59:42 -07:00
parent f2cc3666f9
commit e781b569b2
1 changed files with 9 additions and 1 deletions

View File

@ -388,6 +388,14 @@ void BatchedProduct(char TA,
gemmBatched(TB, TA, M, N, K, element(alpha), Bi.data(), ldb, Ai.data(), lda, element(beta), Ci.data(), ldc, nbatch);
}
template<class T> auto generic_sizes(T const& A)
->decltype(std::array<std::size_t, 2>{A.size(0), A.size(1)}) {
return std::array<std::size_t, 2>{A.size(0), A.size(1)}; }
template<class T> auto generic_sizes(T const& A)
->decltype(A.sizes()) {
return A.sizes(); }
// no batched sparse product yet, serialize call
template<class T,
class MultiArrayPtr2DA,
@ -428,7 +436,7 @@ void BatchedProduct(char TA,
for (int i = 0; i < nbatch; i++)
{
csrmm(TA, (*A[i]).size(), (*B[i]).size(1), (*A[i]).size(1), elementA(alpha), "GxxCxx",
csrmm(TA, (*A[i]).size(), std::get<1>(generic_sizes(*B[i])), std::get<1>(generic_sizes(*A[i])), elementA(alpha), "GxxCxx",
pointer_dispatch((*A[i]).non_zero_values_data()), pointer_dispatch((*A[i]).non_zero_indices2_data()),
pointer_dispatch((*A[i]).pointers_begin()), pointer_dispatch((*A[i]).pointers_end()),
pointer_dispatch((*B[i]).origin()), (*B[i]).stride(), elementA(beta), pointer_dispatch((*C[i]).origin()),