add spmm
This commit is contained in:
parent
b65f1c2036
commit
13df2bd8b7
19
README.md
19
README.md
|
@ -21,7 +21,7 @@ python setup.py develop
|
|||
### Architecture
|
||||
|
||||
```
|
||||
- JSparse
|
||||
- jsparse
|
||||
- nn
|
||||
- functional
|
||||
- modules
|
||||
|
@ -29,7 +29,7 @@ python setup.py develop
|
|||
collate/quantize/utils.py
|
||||
```
|
||||
|
||||
You can use the modules from `JSparse/modules` .
|
||||
You can use the modules from `jsparse/modules` .
|
||||
|
||||
### Sparse Tensor
|
||||
|
||||
|
@ -50,12 +50,12 @@ We can then use `sparse_collate_fn` (provided in `JSparse.utils.collate`) to ass
|
|||
|
||||
### Sparse Neural Network
|
||||
|
||||
We finished many common modules in `JSparse.nn` such like `MaxPool`, `GlobalMaxPool`.
|
||||
We have implemented many common modules in `jsparse.nn`, such as `GlobalPool`.
|
||||
|
||||
The neural network interface in JSparse is similar to Jittor:
|
||||
The neural network interface in jsparse is similar to Jittor:
|
||||
|
||||
```python
|
||||
import JSparse.nn as spnn
|
||||
import jsparse.nn as spnn
|
||||
def get_conv_block(self, in_channel, out_channel, kernel_size, stride):
|
||||
return nn.Sequential(
|
||||
spnn.Conv3d(
|
||||
|
@ -111,12 +111,17 @@ model = nn.Sequential(
|
|||
We finished two versions of Sparse Convolution (the convolution function is implemented either with Jittor operators or with CUDA).
|
||||
|
||||
We set `batch_size = 2, total_len = 10` and run on an RTX 3080 to measure the time per iteration (JSparse version `v0.5.0`).
|
||||
| | JSparse | TorchSparse(v1.4.0) |
|
||||
|-------------------|---------------|----------------------|
|
||||
| voxel_size = 0.50 | 20.05ms | 33.66ms |
|
||||
| voxel_size = 0.10 | 25.15ms | 40.40ms |
|
||||
| voxel_size = 0.02 | 81.37ms | 87.42ms |
|
||||
|
||||
| | JSparse(jittor) | JSparse(cuda) | TorchSparse(v1.4.0) |
|
||||
<!-- | | JSparse(jittor) | JSparse(cuda) | TorchSparse(v1.4.0) |
|
||||
|-------------------|-----------------|---------------|----------------------|
|
||||
| voxel_size = 0.50 | 26.60ms | 20.05ms | 33.66ms |
|
||||
| voxel_size = 0.10 | 32.34ms | 25.15ms | 40.40ms |
|
||||
| voxel_size = 0.02 | 86.89ms | 81.37ms | 87.42ms |
|
||||
| voxel_size = 0.02 | 86.89ms | 81.37ms | 87.42ms | -->
|
||||
|
||||
We also test the same 200 scenes of ScanNet on [VMNet](https://github.com/hzykent/VMNet) on JSparse and TorchSparse.
|
||||
|
||||
|
|
|
@ -0,0 +1,270 @@
|
|||
from unittest import result
|
||||
from typing import Union, Tuple
|
||||
|
||||
import jittor as jt
|
||||
from jittor import Function
|
||||
|
||||
from jsparse import SparseTensor
|
||||
|
||||
__all__ = ['spmm']
|
||||
|
||||
|
||||
def spmm(
    rows: jt.Var,
    cols: jt.Var,
    vals: jt.Var,
    size: Union[Tuple[int, int], jt.NanoVector],
    mat: jt.Var,
    is_sorted: bool = False,
    cuda_spmm_alg: int = 1,
) -> jt.Var:
    """Multiply a COO sparse matrix by a dense matrix with cuSPARSE.

    The sparse operand is described by the triplets ``(rows, cols, vals)``
    and has shape ``size``. The product is computed by a JIT-compiled
    ``jt.code`` CUDA operator that calls ``cusparseSpMM``.

    Args:
        rows: row index of every stored entry (cast to int32 for cuSPARSE).
        cols: column index of every stored entry (cast to int32).
        vals: value of every stored entry; must have the same dtype as
            ``mat``.
        size: ``(size[0], size[1])`` shape of the sparse matrix.
        mat: dense operand. The kernel reads it as ``size[1]`` rows of
            ``mat.shape[1]`` values.
        is_sorted: when ``False`` the triplets are copied and sorted by row
            index on the device before the multiplication; when ``True``
            the inputs are handed to cuSPARSE as they are.
        cuda_spmm_alg: ``1``-``4`` select ``CUSPARSE_COOMM_ALG1``,
            ``CUSPARSE_COOMM_ALG2``, ``CUSPARSE_COOMM_ALG3`` and
            ``CUSPARSE_SPMM_COO_ALG4``; any other value selects
            ``CUSPARSE_MM_ALG_DEFAULT``.

    Returns:
        A ``jt.Var`` of shape ``(size[0], mat.shape[1])``.

    Raises:
        AssertionError: if the triplet lengths differ or ``vals`` and
            ``mat`` have different dtypes.
        NotImplementedError: if ``jt.flags.use_cuda`` is not positive.
    """
    assert len(rows) == len(cols), "Invalid length"
    assert len(rows) == len(vals), "Invalid length"
    assert vals.dtype == mat.dtype, "dtype mismatch"

    if jt.flags.use_cuda <= 0:
        raise NotImplementedError("spmm is only implemented for CUDA")

    # The sparse descriptor below is created with CUSPARSE_INDEX_32I.
    rows = rows.int32()
    cols = cols.int32()

    # cuSPARSE writes a column-major (size[0], K) result, which is the same
    # memory as a row-major (K, size[0]) array; it is transposed on return.
    output_size = (mat.shape[1], size[0])

    cuda_header = """
#undef out
#include <assert.h>
#include <type_traits>
#include <executor.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <cusparse.h>

#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/tuple.h>
#include <thrust/iterator/zip_iterator.h>

// Map the element type of the operands to the matching cuSPARSE data type.
template <typename scalar_t>
cudaDataType getDtype(const scalar_t *ptr) {
    assert((std::is_same<scalar_t, jittor::float32>::value || std::is_same<scalar_t, jittor::float64>::value));
    return std::is_same<scalar_t, jittor::float32>::value ? CUDA_R_32F : CUDA_R_64F;
}
"""

    cuda_aliases = """
@alias(rows, in0)
@alias(cols, in1)
@alias(vals, in2)
@alias(mat2, in3)
@alias(result, out0)
"""

    # Values known on the Python side are baked into the generated source.
    cuda_params = f"""
const int64_t dim_i = {size[0]};
const int64_t dim_j = {size[1]};
const int64_t spmm_algorithm_id = {cuda_spmm_alg};
const bool is_sorted = {'true' if is_sorted else 'false'};
const int64_t nnz = {rows.numel()};
"""

    cuda_body = """
cusparseHandle_t handle = 0;
cusparseCreate(&handle);

cusparseSpMMAlg_t mm_alg;
switch (spmm_algorithm_id) {
    case 1:
        mm_alg = CUSPARSE_COOMM_ALG1;
        break;
    case 2:
        mm_alg = CUSPARSE_COOMM_ALG2;
        break;
    case 3:
        mm_alg = CUSPARSE_COOMM_ALG3;
        break;
    case 4:
        mm_alg = CUSPARSE_SPMM_COO_ALG4;
        break;
    default:
        mm_alg = CUSPARSE_MM_ALG_DEFAULT;
}

cudaDeviceSynchronize();

const int64_t dim_k = mat2_shape1;

cudaMemset(result_p, 0, result->size);

// alpha/beta and the value buffer use the operand element type so that
// they always agree with the compute type passed to cuSPARSE.
const mat2_type alpha = 1;
const mat2_type beta = 0;

cudaDataType cuda_data_type = getDtype<mat2_type>(mat2_p);

int *sorted_rows_ptr, *sorted_cols_ptr;
mat2_type *sorted_vals_ptr;
size_t sorted_rows_allocation, sorted_vals_allocation;

if (!is_sorted) {
    // Rows and columns share one allocation: [rows | cols].
    sorted_rows_ptr = (int *)exe.allocator->alloc(2 * nnz * sizeof(int), sorted_rows_allocation);
    sorted_cols_ptr = sorted_rows_ptr + nnz;
    sorted_vals_ptr = (mat2_type *)exe.allocator->alloc(nnz * sizeof(mat2_type), sorted_vals_allocation);

    cudaMemcpy(sorted_rows_ptr, rows_p, nnz * sizeof(int), cudaMemcpyDeviceToDevice);
    cudaMemcpy(sorted_cols_ptr, cols_p, nnz * sizeof(int), cudaMemcpyDeviceToDevice);
    cudaMemcpy(sorted_vals_ptr, vals_p, nnz * sizeof(mat2_type), cudaMemcpyDeviceToDevice);

    // Sort the copied triplets by row index; columns and values follow.
    thrust::sort_by_key(thrust::device,
                        sorted_rows_ptr,
                        sorted_rows_ptr + nnz,
                        thrust::make_zip_iterator(
                            thrust::make_tuple(
                                sorted_cols_ptr,
                                sorted_vals_ptr
                            )));
} else {
    sorted_rows_ptr = (int *) rows_p;
    sorted_cols_ptr = (int *) cols_p;
    sorted_vals_ptr = (mat2_type *) vals_p;
}

cudaDeviceSynchronize();

cusparseSpMatDescr_t sparse_descr;
cusparseCreateCoo(
    &sparse_descr,
    dim_i, dim_j, nnz,
    (void*) sorted_rows_ptr,
    (void*) sorted_cols_ptr,
    (void*) sorted_vals_ptr,
    CUSPARSE_INDEX_32I,
    CUSPARSE_INDEX_BASE_ZERO, cuda_data_type);

// The dense operand is described as a column-major (dim_k, dim_j) matrix
// and multiplied with CUSPARSE_OPERATION_TRANSPOSE.
cusparseDnMatDescr_t dense_descr;
cusparseCreateDnMat(&dense_descr,
    dim_k, dim_j, dim_k,
    (void*) mat2_p,
    cuda_data_type, CUSPARSE_ORDER_COL);

cusparseDnMatDescr_t result_descr;
cusparseCreateDnMat(&result_descr,
    dim_i, dim_k, dim_i,
    (void*) result_p,
    cuda_data_type, CUSPARSE_ORDER_COL);

size_t workspace_buffer_size = 0;
void *workspace_buffer = nullptr;
cusparseSpMM_bufferSize(
    handle,
    CUSPARSE_OPERATION_NON_TRANSPOSE,
    CUSPARSE_OPERATION_TRANSPOSE,
    (void*) &alpha,
    sparse_descr, dense_descr,
    (void*) &beta,
    result_descr,
    cuda_data_type, mm_alg,
    &workspace_buffer_size);

if (workspace_buffer_size > 0) {
    cudaMallocManaged(&workspace_buffer, workspace_buffer_size);
}

cusparseSpMM(handle,
    CUSPARSE_OPERATION_NON_TRANSPOSE,
    CUSPARSE_OPERATION_TRANSPOSE,
    (void*) &alpha,
    sparse_descr, dense_descr,
    (void*) &beta,
    result_descr,
    cuda_data_type, mm_alg,
    workspace_buffer);

// Wait for the multiplication before releasing anything it reads.
cudaDeviceSynchronize();

cusparseDestroySpMat(sparse_descr);
cusparseDestroyDnMat(dense_descr);
cusparseDestroyDnMat(result_descr);
// The handle is created on every call, so it must be destroyed here.
cusparseDestroy(handle);

if (!is_sorted) {
    exe.allocator->free(sorted_rows_ptr, 2 * nnz * sizeof(int), sorted_rows_allocation);
    exe.allocator->free(sorted_vals_ptr, nnz * sizeof(mat2_type), sorted_vals_allocation);
}

if (workspace_buffer != nullptr) {
    cudaFree(workspace_buffer);
}

cudaDeviceSynchronize();
"""

    result = jt.code(
        output_size,
        vals.dtype,
        [rows, cols, vals, mat],
        cuda_header=cuda_header,
        cuda_src=cuda_aliases + cuda_params + cuda_body,
    )
    return result.t()
|
||||
|
||||
class SPMM(Function):
    """Differentiable sparse (COO) x dense matrix product built on ``spmm``."""

    def execute(
        self,
        rows: jt.Var,
        cols: jt.Var,
        vals: jt.Var,
        size: Union[Tuple[int, int], jt.NanoVector],
        mat: jt.Var,
        cuda_spmm_alg: int = 1,
    ):
        """Compute the product and remember the inputs for ``grad``.

        Args:
            rows: row index of every stored entry of the sparse matrix.
            cols: column index of every stored entry.
            vals: value of every stored entry.
            size: shape of the sparse matrix; stored as a plain tuple.
            mat: dense right-hand operand.
            cuda_spmm_alg: algorithm selector forwarded to ``spmm``.

        Returns:
            The value returned by ``spmm`` for these arguments.
        """
        size = tuple(size)
        self.save_vars = rows, cols, vals, size, mat, cuda_spmm_alg
        return spmm(
            rows,
            cols,
            vals,
            size,
            mat,
            is_sorted=False,
            cuda_spmm_alg=cuda_spmm_alg,
        )

    def grad(
        self,
        grad: jt.Var
    ):
        """Back-propagate ``grad`` to ``vals`` and ``mat``.

        Returns:
            One entry per ``execute`` argument; only ``vals`` and ``mat``
            receive a gradient, the other entries are ``None``.
        """
        rows, cols, vals, size, mat, cuda_spmm_alg = self.save_vars

        # Entry k contributes vals[k] * mat[cols[k]] to output row rows[k],
        # so its gradient is the dot product of grad[rows[k]] and
        # mat[cols[k]]. This yields one value per stored entry instead of
        # the dense (size[0], size[1]) product grad @ mat^T.
        vals_grad = (grad[rows] * mat[cols]).sum(1)

        # The gradient of the dense operand is A^T @ grad; swapping the row
        # and column indices expresses A^T in COO form.
        transposed_size = (size[1], size[0])
        mat_grad = spmm(
            cols,
            rows,
            vals,
            transposed_size,
            grad,
            is_sorted=False,
            cuda_spmm_alg=cuda_spmm_alg,
        )

        return None, None, vals_grad, None, mat_grad, None
|
Loading…
Reference in New Issue