mirror of https://github.com/tracel-ai/burn.git
Migrate/jit/cat (#1457)
This commit is contained in:
parent
41d01b8e19
commit
cf3c1ca80a
|
@ -180,21 +180,3 @@ pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> Wor
|
|||
|
||||
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use core::any::TypeId;

    /// `KernelSettings` must encode its const-generic workgroup sizes in its
    /// type identity: settings that differ only in one const parameter get
    /// distinct `TypeId`s, while identical parameter lists compare equal.
    /// This is what lets kernels be cached/deduplicated by type.
    #[test]
    fn test_kernel_type_id() {
        // Project macro: declares a kernel type `Cat` backed by the WGSL template.
        kernel_wgsl!(Cat, "../template/cat.wgsl");

        let type_id_1 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 4>>();
        let type_id_2 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 5>>();
        let type_id_3 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 4>>();

        // Differ in the last const parameter (4 vs 5) -> different types.
        assert_ne!(type_id_1, type_id_2);
        // Same parameters -> same type.
        assert_eq!(type_id_1, type_id_3);
    }
}
|
||||
|
|
|
@ -1,56 +0,0 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::JitElement,
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel_wgsl,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
|
||||
use super::WORKGROUP_DEFAULT;
|
||||
|
||||
kernel_wgsl!(Cat, "../template/cat.wgsl");
|
||||
|
||||
pub fn cat<R: Runtime, E: JitElement, const D: usize>(
|
||||
inputs: Vec<JitTensor<R, E, D>>,
|
||||
dim: usize,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let first_input = inputs.first().unwrap();
|
||||
let client = &first_input.client;
|
||||
let mut shape_output = first_input.shape.clone();
|
||||
shape_output.dims[dim] = inputs.iter().map(|input| input.shape.dims[dim]).sum();
|
||||
|
||||
let buffer = first_input
|
||||
.client
|
||||
.empty(shape_output.num_elements() * std::mem::size_of::<E>());
|
||||
|
||||
let output = JitTensor::new(
|
||||
client.clone(),
|
||||
first_input.device.clone(),
|
||||
shape_output,
|
||||
buffer,
|
||||
);
|
||||
|
||||
let mut dim_cat_index = 0;
|
||||
|
||||
for input in inputs.iter() {
|
||||
let mut info = build_info(&[input, &output]);
|
||||
info.push(dim as u32);
|
||||
info.push(dim_cat_index as u32);
|
||||
dim_cat_index += input.shape.dims[dim];
|
||||
let info_buffer = client.create(bytemuck::cast_slice(&info));
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<Cat, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(
|
||||
input.shape.num_elements(),
|
||||
WORKGROUP_DEFAULT,
|
||||
));
|
||||
|
||||
client.execute(
|
||||
Box::new(kernel),
|
||||
&[&input.handle, &output.handle, &info_buffer],
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
mod base;
|
||||
mod binary;
|
||||
mod cast;
|
||||
mod cat;
|
||||
mod clamp;
|
||||
mod comparison;
|
||||
mod contiguous;
|
||||
|
@ -31,7 +30,6 @@ pub mod prng;
|
|||
/// Reduction algorithms
|
||||
pub mod reduce;
|
||||
|
||||
pub(crate) use cat::*;
|
||||
pub(crate) use clamp::*;
|
||||
pub(crate) use comparison::*;
|
||||
pub(crate) use index::*;
|
||||
|
|
|
@ -73,13 +73,6 @@ impl<R: Runtime> BoolTensorOps<Self> for JitBackend<R> {
|
|||
kernel::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
// Overrides the trait's slice-assign default with the dedicated JIT cat kernel.
fn bool_cat<const D: usize>(
    tensors: Vec<BoolTensor<Self, D>>,
    dim: usize,
) -> BoolTensor<Self, D> {
    kernel::cat(tensors, dim)
}
|
||||
|
||||
fn bool_equal<const D: usize>(
|
||||
lhs: BoolTensor<Self, D>,
|
||||
rhs: BoolTensor<Self, D>,
|
||||
|
|
|
@ -461,13 +461,6 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
)
|
||||
}
|
||||
|
||||
// Overrides the trait's slice-assign default with the dedicated JIT cat kernel.
fn float_cat<const D: usize>(
    tensors: Vec<FloatTensor<Self, D>>,
    dim: usize,
) -> FloatTensor<Self, D> {
    kernel::cat(tensors, dim)
}
|
||||
|
||||
fn float_argmax<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
dim: usize,
|
||||
|
|
|
@ -109,10 +109,6 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
|
|||
kernel::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
// Overrides the trait's slice-assign default with the dedicated JIT cat kernel.
fn int_cat<const D: usize>(tensors: Vec<IntTensor<Self, D>>, dim: usize) -> IntTensor<Self, D> {
    kernel::cat(tensors, dim)
}
|
||||
|
||||
fn int_equal<const D: usize>(
|
||||
lhs: IntTensor<Self, D>,
|
||||
rhs: IntTensor<Self, D>,
|
||||
|
|
|
@ -1,51 +0,0 @@
|
|||
// Concatenation copy kernel: each invocation copies one element of `input`
// into its position in the (pre-allocated) `output`, offset along the
// concatenation dimension.

@group(0)
@binding(0)
var<storage, read> input: array<{{ elem }}>;

@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;

// Packed metadata. From the indexing below the layout appears to be
// (TODO confirm against the host-side `build_info`):
//   info[0]            rank D
//   info[1 .. D]       input strides
//   info[D+1 .. 2D]    output strides
//   info[2D+1 .. 3D]   input shape
//   info[3D+1 .. 4D]   output shape
//   info[4D+1]         concatenation dimension
//   info[4D+2]         this input's offset along that dimension
@group(0)
@binding(2)
var<storage, read> info: array<u32>;

const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // Flatten the 2D dispatch grid into a linear element index.
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
    // NOTE(review): `dim` here is the tensor rank (info[0]), not the cat axis.
    let dim: u32 = info[0];
    let dim_cat = info[4u * dim + 1u]; // axis being concatenated (0-based)
    let dim_cat_index = info[4u * dim + 2u]; // offset of this input along that axis

    var num_elems = 1u;
    var index_input = 0u;
    var index_output = 0u;

    // Decompose `id` into per-axis coordinates using the input strides, and
    // rebuild the flat input/output offsets. On the cat axis the coordinate
    // is shifted by `dim_cat_index`.
    for (var i: u32 = 1u; i <= dim; i++) {
        let stride_input = info[i];
        let stride_output = info[i + dim];
        let shape_input = info[i + 2u * dim];
        let shape_output = info[i + 3u * dim]; // read but unused below

        let num_block_output = id / stride_input % shape_input;
        index_input += num_block_output * stride_input;
        num_elems *= shape_input;

        if i - 1u == dim_cat {
            index_output += (num_block_output + dim_cat_index) * stride_output;
        } else {
            index_output += num_block_output * stride_output;
        }
    }

    // The dispatch is padded up to whole workgroups; drop out-of-range ids.
    if id < num_elems {
        output[index_output] = input[index_input];
    }
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
use super::{BoolTensor, Device, FloatTensor, IntTensor};
|
||||
use super::{cat::cat_with_slice_assign, BoolTensor, Device, FloatTensor, IntTensor};
|
||||
use crate::{
|
||||
argwhere, backend::Backend, chunk, narrow, tensor::Shape, Bool, Data, ElementConversion,
|
||||
argwhere, backend::Backend, chunk, narrow, tensor::Shape, Bool, Data, ElementConversion, Tensor,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
|
@ -205,7 +205,16 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The tensor with the tensors concatenated along the given dimension.
|
||||
fn bool_cat<const D: usize>(tensors: Vec<BoolTensor<B, D>>, dim: usize) -> BoolTensor<B, D>;
|
||||
fn bool_cat<const D: usize>(tensors: Vec<BoolTensor<B, D>>, dim: usize) -> BoolTensor<B, D> {
|
||||
cat_with_slice_assign::<B, D, Bool>(
|
||||
tensors
|
||||
.into_iter()
|
||||
.map(Tensor::<B, D, Bool>::from_primitive)
|
||||
.collect(),
|
||||
dim,
|
||||
)
|
||||
.into_primitive()
|
||||
}
|
||||
|
||||
/// Equates the two tensors.
|
||||
///
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use super::cat::cat_with_slice_assign;
|
||||
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||
use crate::Tensor;
|
||||
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};
|
||||
use crate::{tensor::api::chunk, tensor::api::narrow};
|
||||
use alloc::vec::Vec;
|
||||
|
@ -299,7 +301,16 @@ pub trait IntTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The concatenated tensor.
|
||||
fn int_cat<const D: usize>(tensors: Vec<IntTensor<B, D>>, dim: usize) -> IntTensor<B, D>;
|
||||
fn int_cat<const D: usize>(tensors: Vec<IntTensor<B, D>>, dim: usize) -> IntTensor<B, D> {
|
||||
cat_with_slice_assign::<B, D, Int>(
|
||||
tensors
|
||||
.into_iter()
|
||||
.map(Tensor::<B, D, Int>::from_primitive)
|
||||
.collect(),
|
||||
dim,
|
||||
)
|
||||
.into_primitive()
|
||||
}
|
||||
|
||||
/// Element-wise equality comparison.
|
||||
///
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
use crate::{backend::Backend, BasicOps, Tensor, TensorKind};
|
||||
use alloc::vec::Vec;
|
||||
|
||||
pub(crate) fn cat_with_slice_assign<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
|
||||
tensors: Vec<Tensor<B, D, K>>,
|
||||
dim: usize,
|
||||
) -> Tensor<B, D, K> {
|
||||
let first_tensor = tensors.first().expect("Tensors should not be empty");
|
||||
let mut shape = first_tensor.shape();
|
||||
let device = first_tensor.device();
|
||||
|
||||
let output_dim_length: usize = tensors
|
||||
.iter()
|
||||
.map(|tensor: &Tensor<B, D, K>| tensor.shape().dims[dim])
|
||||
.sum();
|
||||
shape.dims[dim] = output_dim_length;
|
||||
|
||||
let mut tensor_output = Tensor::empty(shape.clone(), &device);
|
||||
|
||||
let mut i = 0;
|
||||
let indices_select_all = [0; D].map(|_| {
|
||||
let start = 0;
|
||||
let end = shape.dims[i];
|
||||
i += 1;
|
||||
start..end
|
||||
});
|
||||
|
||||
let mut output_index = 0;
|
||||
for tensor in tensors {
|
||||
let mut indices = indices_select_all.clone();
|
||||
let tensor_dim_length = tensor.shape().dims[dim];
|
||||
indices[dim] = output_index..output_index + tensor_dim_length;
|
||||
output_index += tensor_dim_length;
|
||||
|
||||
tensor_output = tensor_output.slice_assign(indices, tensor);
|
||||
}
|
||||
|
||||
tensor_output
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
/// Module with convolution operations.
|
||||
pub mod conv;
|
||||
|
||||
/// Module with cat operation
|
||||
pub(crate) mod cat;
|
||||
/// Module with unfold operations.
|
||||
pub(crate) mod unfold;
|
||||
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use super::cat::cat_with_slice_assign;
|
||||
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
|
||||
use crate::Tensor;
|
||||
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Float};
|
||||
use crate::{tensor::api::chunk, tensor::api::narrow};
|
||||
use alloc::vec::Vec;
|
||||
|
@ -1074,17 +1076,26 @@ pub trait FloatTensorOps<B: Backend> {
|
|||
/// A tensor with the same shape as `tensor` with error function values.
|
||||
fn float_erf<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
|
||||
|
||||
/// Catcatenates tensors along a dimension.
|
||||
/// Concatenates tensors along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensors` - The tensors to catcatenate.
|
||||
/// * `dim` - The dimension along which to catcatenate.
|
||||
/// * `tensors` - The tensors to concatenate.
|
||||
/// * `dim` - The dimension along which to concatenate.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the catcatenated tensors along `dim`.
|
||||
fn float_cat<const D: usize>(tensors: Vec<FloatTensor<B, D>>, dim: usize) -> FloatTensor<B, D>;
|
||||
/// A tensor with the concatenated tensors along `dim`.
|
||||
fn float_cat<const D: usize>(tensors: Vec<FloatTensor<B, D>>, dim: usize) -> FloatTensor<B, D> {
|
||||
cat_with_slice_assign::<B, D, Float>(
|
||||
tensors
|
||||
.into_iter()
|
||||
.map(Tensor::<B, D>::from_primitive)
|
||||
.collect(),
|
||||
dim,
|
||||
)
|
||||
.into_primitive()
|
||||
}
|
||||
|
||||
/// Gets the indices of the maximum elements of a tensor along an axis.
|
||||
///
|
||||
|
|
Loading…
Reference in New Issue