192 lines
7.4 KiB
Python
192 lines
7.4 KiB
Python
import jittor as jt
|
|
|
|
__all__ = ['unique']
|
|
|
|
|
|
def unique(
|
|
input: jt.Var,
|
|
return_inverse: bool=False,
|
|
return_counts: bool=False,
|
|
dim: int=None):
|
|
|
|
temp_shape = None
|
|
if dim == None:
|
|
temp_shape = list(input.shape)
|
|
input_flatten = input.flatten()
|
|
dim = 0
|
|
else:
|
|
input_flatten = input
|
|
|
|
input_flatten = input_flatten.transpose(dim, 0)
|
|
orig_shape = input_flatten.shape
|
|
input_flatten = input_flatten.view(orig_shape[0], -1)
|
|
|
|
indice = jt.code((input_flatten.shape[0], ), 'int32', [input_flatten],
|
|
cpu_header='''
|
|
#include <algorithm>
|
|
''',
|
|
cpu_src='''
|
|
@alias(input_flatten, in0)
|
|
@alias(indice, out)
|
|
|
|
int dimlen = input_flatten_shape0, dimsize = input_flatten_shape1;
|
|
for(int i = 0; i < dimlen; ++i) @indice(i) = i;
|
|
std::sort(&@indice(0), &@indice(dimlen), [&](int a, int b){
|
|
for(int i = 0; i < dimsize; ++i) {
|
|
int lhs = @input_flatten(a, i), rhs = @input_flatten(b, i);
|
|
if (lhs != rhs) return lhs < rhs;
|
|
}
|
|
return false;
|
|
});
|
|
''',
|
|
cuda_header='''
|
|
#undef out
|
|
#include <thrust/extrema.h>
|
|
#include <thrust/device_ptr.h>
|
|
#include <thrust/execution_policy.h>
|
|
#include <thrust/device_vector.h>
|
|
#include <thrust/sequence.h>
|
|
|
|
#include <cub/cub.cuh>
|
|
#include <executor.h>
|
|
''',
|
|
cuda_src=
|
|
'''
|
|
@alias(input_flatten, in0)
|
|
@alias(indice, out)
|
|
int dimlen = indice_shape0, dimsize = input_flatten_shape1;
|
|
|
|
if (dimsize == 1) {
|
|
size_t raw_allocation, d_allocation, temp_storage_bytes = 0;
|
|
void *d_temp_storage = NULL;
|
|
int* raw_ptr = (int*)exe.allocator->alloc(dimlen * (sizeof(int) + sizeof(input_flatten_type)), raw_allocation);
|
|
|
|
thrust::device_ptr<int32_t> arange_ptr = thrust::device_pointer_cast(raw_ptr);
|
|
thrust::sequence(arange_ptr, arange_ptr + dimlen);
|
|
|
|
cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, input_flatten_p, raw_ptr + dimlen, thrust::raw_pointer_cast(arange_ptr), indice_p, dimlen);
|
|
d_temp_storage = exe.allocator->alloc(temp_storage_bytes, d_allocation);
|
|
cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, input_flatten_p, raw_ptr + dimlen, thrust::raw_pointer_cast(arange_ptr), indice_p, dimlen);
|
|
|
|
exe.allocator->free(raw_ptr, dimlen * (sizeof(int) + sizeof(input_flatten_type)), raw_allocation);
|
|
exe.allocator->free(d_temp_storage, temp_storage_bytes, d_allocation);
|
|
} else {
|
|
thrust::device_ptr<input_flatten_type> input_ptr = thrust::device_pointer_cast(input_flatten_p);
|
|
thrust::device_ptr<int32_t> indice_ptr = thrust::device_pointer_cast(indice_p);
|
|
|
|
thrust::sequence(indice_ptr, indice_ptr + dimlen);
|
|
thrust::sort(thrust::device, indice_ptr, indice_ptr + dimlen,
|
|
[=] __device__ (int32_t a, int32_t b)->bool {
|
|
for(int i = 0; i < dimsize; ++i) {
|
|
input_flatten_type lhs = input_ptr[i + a * dimsize],
|
|
rhs = input_ptr[i + b * dimsize];
|
|
if (lhs != rhs) return lhs < rhs;
|
|
}
|
|
return false;
|
|
});
|
|
}
|
|
'''
|
|
)
|
|
input_sorted = input_flatten[indice][:]
|
|
|
|
dimlen = indice.shape[0]
|
|
|
|
diff = jt.logical_not(jt.all(input_sorted[1:] == input_sorted[: -1], 1))
|
|
diff = jt.concat([jt.array([False], dtype='bool'), diff], 0)
|
|
diff = jt.array(diff, dtype = jt.int32)
|
|
|
|
output, inverse = jt.code(
|
|
[(-input_sorted.shape[0], ), (indice.shape)],
|
|
[input_sorted.dtype, indice.dtype],
|
|
[input_sorted, diff, indice],
|
|
cpu_header='''
|
|
#include <algorithm>
|
|
@alias(input_sorted, in0)
|
|
@alias(diff, in1)
|
|
@alias(indice, in2)
|
|
@alias(output, out0)
|
|
@alias(inverse, out1)
|
|
''',
|
|
cpu_src=
|
|
f"bool return_inverse = {int(return_inverse)};" +
|
|
'''
|
|
int tot = -1;
|
|
bool return_inverse = @out2(0);
|
|
for (int i = 0; i < input_sorted_shape0; ++i) {
|
|
if (i == 0 || @diff(i)) {
|
|
++tot; @output(tot) = i;
|
|
}
|
|
if (return_inverse)
|
|
@inverse(@indice(i)) = tot;
|
|
}
|
|
output->set_shape({tot + 1});
|
|
''',
|
|
cuda_header='''
|
|
#undef out
|
|
|
|
#include <thrust/extrema.h>
|
|
#include <thrust/device_ptr.h>
|
|
#include <thrust/execution_policy.h>
|
|
#include <thrust/scan.h>
|
|
#include <executor.h>
|
|
|
|
@alias(input_sorted, in0)
|
|
@alias(diff, in1)
|
|
@alias(indice, in2)
|
|
@alias(output, out0)
|
|
@alias(inverse, out1)
|
|
''',
|
|
cuda_src=
|
|
f"bool return_inverse = {int(return_inverse)};" +
|
|
'''
|
|
int dimlen = input_sorted_shape0, dimsize = input_sorted_shape1;
|
|
size_t raw_allocation;
|
|
int* raw_ptr = (int*)exe.allocator->alloc(2 * dimlen * sizeof(int), raw_allocation);
|
|
|
|
thrust::device_ptr<int32_t> diff_ptr = thrust::device_pointer_cast(diff_p),
|
|
inverse_ptr = thrust::device_pointer_cast(inverse_p),
|
|
array_ptr = thrust::device_pointer_cast(raw_ptr),
|
|
sum_ptr = thrust::device_pointer_cast(raw_ptr + dimlen),
|
|
indice_ptr = thrust::device_pointer_cast(indice_p);
|
|
thrust::device_ptr<input_sorted_type> input_ptr = thrust::device_pointer_cast(input_sorted_p);
|
|
|
|
if (return_inverse) {
|
|
thrust::inclusive_scan(diff_ptr, diff_ptr + dimlen, sum_ptr);
|
|
thrust::scatter(sum_ptr, sum_ptr + dimlen, indice_ptr, inverse_ptr);
|
|
}
|
|
|
|
thrust::sequence(array_ptr, array_ptr + dimlen);
|
|
int num = thrust::unique(array_ptr, array_ptr + dimlen,
|
|
[=] __device__ (int32_t a, int32_t b)->bool {
|
|
for(int i = 0; i < dimsize; ++i) {
|
|
input_sorted_type
|
|
lhs = input_ptr[i + a * dimsize],
|
|
rhs = input_ptr[i + b * dimsize];
|
|
if (lhs != rhs) return false;
|
|
}
|
|
return true;
|
|
}) - array_ptr;
|
|
|
|
cudaMemcpy(output_p, raw_ptr, sizeof(int) * num, cudaMemcpyDeviceToDevice);
|
|
exe.allocator->free(raw_ptr, 2 * dimlen * sizeof(int), raw_allocation);
|
|
output->set_shape({ num });
|
|
'''
|
|
)
|
|
indice_shape = (output.shape[0], )
|
|
output = input_sorted[output][:]
|
|
|
|
new_shape = list(orig_shape[1:])
|
|
new_shape.insert(0, -1)
|
|
output = output.view(new_shape).transpose(dim, 0)
|
|
if temp_shape != None:
|
|
inverse = inverse.view(temp_shape).transpose(dim, 0)
|
|
|
|
if return_inverse:
|
|
if return_counts:
|
|
counts = jt.zeros(indice_shape, dtype=jt.int32)
|
|
jt.scatter_(counts, 0, inverse.flatten(), jt.ones(dimlen), reduce='add')
|
|
return output, inverse, counts
|
|
else:
|
|
return output, inverse
|
|
else:
|
|
return output |