151 lines
5.6 KiB
Python
151 lines
5.6 KiB
Python
import jittor as jt
|
|
from jittor import Function
|
|
|
|
from jsparse import SparseTensor, PointTensor
|
|
from jsparse.nn import functional as F
|
|
|
|
# Public API of this module.
__all__ = [
    'spvoxelize',
    'point_to_voxel',
]
|
|
|
|
class Voxelize(Function):
    """Differentiable scatter-mean of per-point features into voxels.

    Forward: for every point ``i`` with target voxel ``pos = idx_query[i]``,
    ``output[pos, :] += values[i, :] / counts[pos]``.  Assuming ``counts``
    holds the number of points mapped to each voxel, every output row is the
    mean of its points' features.  Points whose ``pos`` lies outside
    ``[0, counts.shape[0])`` or whose voxel has ``counts[pos] == 0`` are
    skipped.

    Backward: ``grad_values[i, :] = grad_output[pos, :] / counts[pos]`` for
    the same valid points; skipped points keep a zero gradient.
    """

    def execute(
        self,
        values: jt.Var,
        idx_query: jt.Var,
        counts: jt.Var
    ) -> jt.Var:
        """Scatter-mean ``values`` into ``counts.shape[0]`` voxel rows.

        Args:
            values: per-point features, indexed by the kernels as ``(N, C)``.
            idx_query: target voxel index for each of the ``N`` points.
            counts: per-voxel divisor, one entry per output row (``N1``).

        Returns:
            A float32 Var of shape ``(N1, C)``.
        """
        # N = values_shape0, C = values_shape1, N1 = counts_shape0; out: N1 x C.
        output = jt.zeros((counts.shape[0], values.shape[1]), dtype='float32')
        # The kernels accumulate into ``output`` (in3) in place; the (0,)-shaped
        # result of jt.code is a dummy and .sync() forces the kernel to run now.
        # CUDA: one thread per (point, channel) pair in 256-thread blocks, so C
        # is not bounded by the 1024-threads-per-block limit; the launch is
        # skipped when N * C == 0 (a zero-sized grid is an invalid launch).
        # CPU: ``pos`` is range-checked exactly like the CUDA kernel so a
        # negative/out-of-range idx_query entry is skipped, not dereferenced.
        jt.code((0, ), 'float32', [values, idx_query, counts, output],
            cuda_header="""
                #include <stdio.h>
                #include <stdlib.h>
                #include <cuda_runtime.h>
            """,
            cuda_src="""
                __global__ void voxelize_forward_kernel(@ARGS_DEF) {
                    @PRECALC
                    @alias(values, in0)
                    @alias(idx_query, in1)
                    @alias(counts, in2)
                    @alias(output, in3)

                    long long index = (long long)blockDim.x * blockIdx.x + threadIdx.x;
                    int c = values_shape1;
                    long long i = index / c;
                    int j = (int)(index % c);

                    if (i < values_shape0) {
                        int pos = @idx_query(i);
                        if (pos < 0 || pos >= counts_shape0 || @counts(pos) == 0) return;
                        atomicAdd(&@output(pos, j), @values(i, j) / (float)(@counts(pos)));
                    }
                }

                @alias(values, in0)
                long long total = (long long)values_shape0 * values_shape1;
                if (total > 0) {
                    voxelize_forward_kernel<<< (unsigned int)((total + 255) / 256), 256 >>>(@ARGS);
                }
            """,
            cpu_src="""
                @alias(values, in0)
                @alias(idx_query, in1)
                @alias(counts, in2)
                @alias(output, in3)

                #pragma omp parallel for
                for (int i = 0; i < values_shape0; ++ i ) {
                    int pos = @idx_query(i);
                    if (pos < 0 || pos >= counts_shape0 || @counts(pos) == 0)
                        continue;
                    for (int j = 0; j < values_shape1; ++ j ) {
                        #pragma omp atomic
                        @output(pos, j) += @values(i, j) / (float)@counts(pos);
                    }
                }
            """
        ).sync()
        self.save_vars = idx_query, counts, values.shape[0]
        return output

    def grad(self, grad_output: jt.Var):
        """Route each voxel's gradient back to the points averaged into it.

        Args:
            grad_output: gradient w.r.t. the forward output, indexed as
                ``(N1, C)``.

        Returns:
            ``(grad_values, None, None)``: ``grad_values`` is a float32 Var of
            shape ``(N, C)``; ``idx_query`` and ``counts`` get no gradient.
        """
        idx_query, counts, input_size = self.save_vars
        grad_values = jt.zeros((input_size, grad_output.shape[1]), dtype='float32')
        # Each point ``i`` reads the gradient row of its voxel ``pos`` and
        # writes its OWN row ``i`` -- grad_values is N x C, so writing row
        # ``pos`` would corrupt other points' gradients or run out of range.
        # Rows are written by exactly one thread each, so no atomics needed.
        # Launch/bounds handling mirrors execute(); grad_values is aliased at
        # host level before the launch so its shape is available there.
        jt.code((0, ), 'float32', [idx_query, counts, grad_output, grad_values],
            cuda_header="""
                #include <stdio.h>
                #include <stdlib.h>
                #include <cuda_runtime.h>
            """,
            cuda_src="""
                __global__ void voxelize_backward_kernel(@ARGS_DEF) {
                    @PRECALC
                    @alias(idx_query, in0)
                    @alias(counts, in1)
                    @alias(grad_output, in2)
                    @alias(grad_values, in3)

                    long long index = (long long)blockDim.x * blockIdx.x + threadIdx.x;
                    int c = grad_output_shape1;
                    long long i = index / c;
                    int j = (int)(index % c);

                    if (i < grad_values_shape0) {
                        int pos = @idx_query(i);
                        if (pos < 0 || pos >= counts_shape0 || @counts(pos) == 0) return;
                        @grad_values(i, j) = @grad_output(pos, j) / (float)@counts(pos);
                    }
                }

                @alias(grad_values, in3)
                long long total = (long long)grad_values_shape0 * grad_values_shape1;
                if (total > 0) {
                    voxelize_backward_kernel<<< (unsigned int)((total + 255) / 256), 256 >>>(@ARGS);
                }
            """,
            cpu_src="""
                @alias(idx_query, in0)
                @alias(counts, in1)
                @alias(grad_output, in2)
                @alias(grad_values, in3)

                #pragma omp parallel for
                for (int i = 0; i < grad_values_shape0; ++ i ) {
                    int pos = @idx_query(i);
                    if (pos < 0 || pos >= counts_shape0 || @counts(pos) == 0)
                        continue;
                    for (int j = 0; j < grad_output_shape1; ++ j ) {
                        @grad_values(i, j) = @grad_output(pos, j) / (float)@counts(pos);
                    }
                }
            """
        ).sync()
        return grad_values, None, None
|
|
|
|
def spvoxelize(values: jt.Var, idx_query: jt.Var, counts: jt.Var) -> jt.Var:
    """Average point features into voxels with the differentiable ``Voxelize`` op.

    Args:
        values: per-point features (the kernels index them as ``(N, C)``).
        idx_query: target voxel index of each point.
        counts: per-voxel divisor; its length is the number of output rows.

    Returns:
        A float32 Var with one row per entry of ``counts`` and one column per
        feature channel of ``values``.
    """
    voxel_features = Voxelize.apply(values, idx_query, counts)
    return voxel_features
|
|
|
|
def point_to_voxel(x: SparseTensor, z: PointTensor) -> SparseTensor:
    """Voxelize the point features of ``z`` onto the sparse voxel grid of ``x``.

    Column 0 of ``z.indices`` is kept as an integer and the remaining columns
    are floored to multiples of ``x.stride[0]``. The result is hashed and
    looked up among the hashes of ``x.indices``. The features of all points
    that land on the same voxel are then averaged via ``F.spvoxelize``.

    The point-to-voxel lookup (``idx_query``) and per-voxel ``counts`` are
    cached in ``z.additional_values['idx_query']`` / ``['counts']``, keyed by
    ``x.stride``. The cache containers are created on demand, so ``z`` is
    mutated on the first call at a given stride.

    Args:
        x: sparse tensor supplying voxel indices, stride, size and maps.
        z: point tensor whose ``values`` are voxelized.

    Returns:
        A new SparseTensor holding the averaged features, sharing ``x``'s
        indices, stride, size, ``cmaps`` and ``kmaps``.
    """
    # Make sure the cache containers exist before reading or writing them;
    # writing into a missing container would raise TypeError/KeyError.
    cache = z.additional_values
    if cache is None:
        cache = z.additional_values = {}
    for key in ('idx_query', 'counts'):
        if cache.get(key) is None:
            cache[key] = {}

    idx_query = cache['idx_query'].get(x.stride)
    counts = cache['counts'].get(x.stride)
    # Recompute when either half of the cached pair is missing.
    if idx_query is None or counts is None:
        # NOTE(review): only x.stride[0] is applied to every spatial column,
        # which assumes an isotropic stride -- confirm.
        point_coords = jt.concat([
            z.indices[:, 0].int().view(-1, 1),
            jt.floor(z.indices[:, 1:] / x.stride[0]).int() * x.stride[0],
        ], 1)
        point_hash = F.sphash(point_coords)
        sparse_hash = F.sphash(x.indices)
        idx_query = F.spquery(point_hash, sparse_hash).int()
        counts = F.spcount(idx_query, x.indices.shape[0])
        cache['idx_query'][x.stride] = idx_query
        cache['counts'][x.stride] = counts

    voxelized_values = F.spvoxelize(z.values, idx_query, counts)
    new_tensor = SparseTensor(voxelized_values, x.indices, x.stride, x.size, False)
    new_tensor.cmaps = x.cmaps
    new_tensor.kmaps = x.kmaps
    return new_tensor
|
|
|