mirror of https://github.com/QMCPACK/qmcpack.git
delicate generic function generic_size
This commit is contained in:
parent
f2cc3666f9
commit
e781b569b2
|
@ -388,6 +388,14 @@ void BatchedProduct(char TA,
|
|||
gemmBatched(TB, TA, M, N, K, element(alpha), Bi.data(), ldb, Ai.data(), lda, element(beta), Ci.data(), ldc, nbatch);
|
||||
}
|
||||
|
||||
template<class T> auto generic_sizes(T const& A)
|
||||
->decltype(std::array<std::size_t, 2>{A.size(0), A.size(1)}) {
|
||||
return std::array<std::size_t, 2>{A.size(0), A.size(1)}; }
|
||||
|
||||
template<class T> auto generic_sizes(T const& A)
|
||||
->decltype(A.sizes()) {
|
||||
return A.sizes(); }
|
||||
|
||||
// no batched sparse product yet, serialize call
|
||||
template<class T,
|
||||
class MultiArrayPtr2DA,
|
||||
|
@ -428,7 +436,7 @@ void BatchedProduct(char TA,
|
|||
|
||||
for (int i = 0; i < nbatch; i++)
|
||||
{
|
||||
csrmm(TA, (*A[i]).size(), (*B[i]).size(1), (*A[i]).size(1), elementA(alpha), "GxxCxx",
|
||||
csrmm(TA, (*A[i]).size(), std::get<1>(generic_sizes(*B[i])), std::get<1>(generic_sizes(*A[i])), elementA(alpha), "GxxCxx",
|
||||
pointer_dispatch((*A[i]).non_zero_values_data()), pointer_dispatch((*A[i]).non_zero_indices2_data()),
|
||||
pointer_dispatch((*A[i]).pointers_begin()), pointer_dispatch((*A[i]).pointers_end()),
|
||||
pointer_dispatch((*B[i]).origin()), (*B[i]).stride(), elementA(beta), pointer_dispatch((*C[i]).origin()),
|
||||
|
|
Loading…
Reference in New Issue