59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
from itertools import repeat
|
|
from typing import List, Tuple, Union
|
|
|
|
import jittor as jt
|
|
import numpy as np
|
|
|
|
from .utils import unique1d
|
|
|
|
# Public names exported by this module; the `hash` helper defined below is not exported.
__all__ = ['sparse_quantize', 'set_hash']
|
|
|
|
def set_hash(ndim, seed, low=100, high=1000):
    """Draw a reproducible vector of random uint64 hash multipliers.

    Seeds Jittor's global random generator with ``seed`` (this affects every
    later Jittor random op, not just this call), then samples ``ndim + 1``
    integers via ``jt.randint(low, high, ...)``.

    Args:
        ndim: number of coordinate dimensions; one extra multiplier is drawn.
        seed: seed passed to ``jt.set_seed``.
        low: lower bound passed to ``jt.randint``.
        high: upper bound passed to ``jt.randint``.

    Returns:
        A 1-D uint64 Jittor Var of length ``ndim + 1``.
    """
    jt.set_seed(seed)
    num_multipliers = ndim + 1
    multipliers = jt.randint(low, high, shape=(num_multipliers,), dtype='uint64')
    return multipliers
|
|
|
|
def hash(x: jt.Var, multiplier: jt.Var) -> jt.Var:
    """Combine each row of a 2-D integer Var into a single uint64 key.

    Every column is first shifted by its own minimum so all values are
    non-negative, then cast to uint64 and folded Horner-style::

        h = (...((x[:, 0] * m[0] + x[:, 1]) * m[1] + ...) * m[D-2]) + x[:, D-1]

    with ``m = multiplier``. The arithmetic is done in uint64, so overflow
    wraps around.

    NOTE(review): ``multiplier`` is not derived from the column ranges (e.g.
    it may be the random values drawn by ``set_hash``), so the key is not
    guaranteed to be collision-free. For D == 2, rows ``(a, b)`` and
    ``(a + 1, b - m[0])`` map to the same key whenever column 1 spans more
    than ``m[0]`` values. Confirm this is acceptable for callers that
    deduplicate on the key.

    Args:
        x: 2-D Var of shape (N, D). This is a Jittor Var, not a NumPy array:
            ``min(dim=...)`` and ``uint64()`` are Jittor methods.
        multiplier: 1-D Var with at least ``D - 1`` entries.

    Returns:
        uint64 Var of shape (N,) holding one key per row.

    Raises:
        AssertionError: if ``x`` is not 2-D or ``multiplier`` is too short.
    """
    assert x.ndim == 2, x.shape
    # Fail loudly instead of indexing past the end of `multiplier` in the loop.
    assert multiplier.shape[0] >= x.shape[1] - 1, (multiplier.shape, x.shape)

    # Shift so every column starts at 0 before the unsigned cast.
    x = (x - x.min(dim=0)).uint64()

    h = jt.zeros(x.shape[0], dtype='uint64')
    for k in range(x.shape[1] - 1):
        h += x[:, k]
        h *= multiplier[k]
    h += x[:, -1]
    return h
|
|
|
|
def sparse_quantize(indices,
                    hash_multiplier,
                    voxel_size: Union[float, Tuple[float, ...]] = 1,
                    *,
                    return_index: bool = False,
                    return_inverse: bool = False,
                    return_count: bool = False) -> Union[jt.Var, List[jt.Var]]:
    """Snap coordinates onto a voxel grid and keep one row per distinct voxel.

    Unless ``indices`` is already integer-typed and ``voxel_size == 1``,
    columns ``1:`` are divided by the per-axis voxel size and every column is
    floored and cast to int32. Column 0 is left unscaled (presumably a batch
    index -- confirm against callers). Rows are then keyed with ``hash`` and
    deduplicated with ``unique1d``. The caller's ``indices`` Var is not
    modified.

    Args:
        indices: 2-D Jittor Var of shape (N, 4): column 0 plus three
            coordinate columns (three because ``voxel_size`` must have length 3).
        hash_multiplier: multipliers forwarded to ``hash`` (e.g. from
            ``set_hash``); needs at least ``indices.shape[1] - 1`` entries.
        voxel_size: a scalar applied to all three axes, or a 3-tuple/list.
        return_index: also return the ``mapping`` produced by ``unique1d``.
        return_inverse: also return the ``inverse_mapping`` produced by
            ``unique1d``.
        return_count: also return the per-key ``count`` produced by
            ``unique1d``.

    Returns:
        The deduplicated indices Var alone when no ``return_*`` flag is set.
        Otherwise a list ``[indices, mapping?, inverse_mapping?, count?]``
        holding only the requested extras, in that order.

    Raises:
        AssertionError: if ``voxel_size`` is not a scalar or a length-3 sequence.
    """
    if not (indices.dtype.is_int() and voxel_size == 1):
        if isinstance(voxel_size, (float, int)):
            voxel_size = tuple(repeat(voxel_size, 3))
        elif isinstance(voxel_size, list):
            voxel_size = tuple(voxel_size)
        assert isinstance(voxel_size, tuple) and len(voxel_size) == 3
        voxel_size = jt.array(voxel_size)

        # Build new Vars rather than the former in-place
        # `indices[:, 1:] /= voxel_size`. That form wrote into the caller's
        # Var and, for integer input, could store the quotient back as an
        # integer (truncating toward zero) before `floor` ever ran.
        batch = indices[:, :1]  # column 0 is not scaled
        coords = indices[:, 1:]
        if coords.dtype.is_int():
            coords = coords.float32()
        coords = coords / voxel_size
        indices = jt.concat([batch.astype(coords.dtype), coords], dim=1)
        indices = jt.floor(indices).astype(jt.int32)

    keys = hash(indices, hash_multiplier)
    mapping, inverse_mapping, count = unique1d(keys, inverse_mapping=True, count=True)
    indices = indices[mapping]

    outputs = [indices]
    if return_index:
        outputs.append(mapping)
    if return_inverse:
        outputs.append(inverse_mapping)
    if return_count:
        outputs.append(count)
    return outputs[0] if len(outputs) == 1 else outputs
|
|
|
|
|