Migrate/jit/cat (#1457)

Louis Fortier-Dubois 2024-03-17 11:37:36 -04:00 committed by GitHub
parent 41d01b8e19
commit cf3c1ca80a
12 changed files with 81 additions and 154 deletions

View File

@@ -180,21 +180,3 @@ pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup {
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
}
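
Only the tail of elemwise_workgroup is visible in this hunk. A minimal sketch of what the visible WorkGroup::new(workgroup_x, workgroup_y, 1) tail suggests, with the body reconstructed as an assumption rather than taken from the source: the required workgroups are spread over a near-square (x, y) grid so a large tensor cannot overflow a single dispatch axis.

// Sketch with an assumed body; only the final tuple shape is from the diff above.
fn elemwise_workgroup_sketch(num_elems: usize, workgroup_size: usize) -> (u32, u32, u32) {
    let elems_per_workgroup = workgroup_size * workgroup_size;
    let workgroups = (num_elems as f32 / elems_per_workgroup as f32).ceil();
    let workgroup_x = workgroups.sqrt().ceil();
    let workgroup_y = (workgroups / workgroup_x).ceil();
    (workgroup_x as u32, workgroup_y as u32, 1)
}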
#[cfg(test)]
mod tests {
use super::*;
use core::any::TypeId;
#[test]
fn test_kernel_type_id() {
kernel_wgsl!(Cat, "../template/cat.wgsl");
let type_id_1 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 4>>();
let type_id_2 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 5>>();
let type_id_3 = TypeId::of::<KernelSettings<Cat, f32, i32, 2, 3, 4>>();
assert_ne!(type_id_1, type_id_2);
assert_eq!(type_id_1, type_id_3);
}
}
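
The deleted test exercised a general language guarantee: const generic arguments are part of a type's identity, so two KernelSettings instantiations that differ only in workgroup size get distinct TypeIds. A self-contained sketch of that property, using a hypothetical Settings type rather than the crate's KernelSettings:

use core::any::TypeId;

// Unit struct whose identity depends on its const parameters.
struct Settings<const X: usize, const Y: usize, const Z: usize>;

fn main() {
    // Differ in one const argument: different types, different TypeIds.
    assert_ne!(TypeId::of::<Settings<2, 3, 4>>(), TypeId::of::<Settings<2, 3, 5>>());
    // Same arguments: same type.
    assert_eq!(TypeId::of::<Settings<2, 3, 4>>(), TypeId::of::<Settings<2, 3, 4>>());
}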

View File

@@ -1,56 +0,0 @@
use crate::{
compute::StaticKernel,
element::JitElement,
kernel::{build_info, elemwise_workgroup, KernelSettings},
kernel_wgsl,
tensor::JitTensor,
Runtime,
};
use super::WORKGROUP_DEFAULT;
kernel_wgsl!(Cat, "../template/cat.wgsl");
pub fn cat<R: Runtime, E: JitElement, const D: usize>(
inputs: Vec<JitTensor<R, E, D>>,
dim: usize,
) -> JitTensor<R, E, D> {
let first_input = inputs.first().unwrap();
let client = &first_input.client;
let mut shape_output = first_input.shape.clone();
shape_output.dims[dim] = inputs.iter().map(|input| input.shape.dims[dim]).sum();
let buffer = first_input
.client
.empty(shape_output.num_elements() * std::mem::size_of::<E>());
let output = JitTensor::new(
client.clone(),
first_input.device.clone(),
shape_output,
buffer,
);
let mut dim_cat_index = 0;
for input in inputs.iter() {
let mut info = build_info(&[input, &output]);
info.push(dim as u32);
info.push(dim_cat_index as u32);
dim_cat_index += input.shape.dims[dim];
let info_buffer = client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<
KernelSettings<Cat, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(
input.shape.num_elements(),
WORKGROUP_DEFAULT,
));
client.execute(
Box::new(kernel),
&[&input.handle, &output.handle, &info_buffer],
);
}
output
}
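
The deleted launcher pushes dim and dim_cat_index onto the payload returned by build_info, then dispatches the kernel once per input. A hedged reconstruction of the buffer layout the cat.wgsl template (deleted below) indexes into; build_cat_info is a hypothetical helper, and build_info's exact output is an assumption inferred from the shader's info[...] offsets:

// Assumed layout: [rank, in_strides.., out_strides.., in_shape.., out_shape.., dim, offset]
fn build_cat_info(
    in_strides: &[u32],
    out_strides: &[u32],
    in_shape: &[u32],
    out_shape: &[u32],
    dim: u32,
    dim_cat_index: u32,
) -> Vec<u32> {
    let rank = in_strides.len() as u32;
    let mut info = vec![rank];
    info.extend_from_slice(in_strides);
    info.extend_from_slice(out_strides);
    info.extend_from_slice(in_shape);
    info.extend_from_slice(out_shape);
    info.push(dim);           // the shader reads this at info[4 * rank + 1]
    info.push(dim_cat_index); // the shader reads this at info[4 * rank + 2]
    info
}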

View File

@@ -1,7 +1,6 @@
mod base;
mod binary;
mod cast;
mod cat;
mod clamp;
mod comparison;
mod contiguous;
@@ -31,7 +30,6 @@ pub mod prng;
/// Reduction algorithms
pub mod reduce;
pub(crate) use cat::*;
pub(crate) use clamp::*;
pub(crate) use comparison::*;
pub(crate) use index::*;

View File

@@ -73,13 +73,6 @@ impl<R: Runtime> BoolTensorOps<Self> for JitBackend<R> {
kernel::slice_assign(tensor, ranges, value)
}
fn bool_cat<const D: usize>(
tensors: Vec<BoolTensor<Self, D>>,
dim: usize,
) -> BoolTensor<Self, D> {
kernel::cat(tensors, dim)
}
fn bool_equal<const D: usize>(
lhs: BoolTensor<Self, D>,
rhs: BoolTensor<Self, D>,

View File

@@ -461,13 +461,6 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
)
}
fn float_cat<const D: usize>(
tensors: Vec<FloatTensor<Self, D>>,
dim: usize,
) -> FloatTensor<Self, D> {
kernel::cat(tensors, dim)
}
fn float_argmax<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,

View File

@@ -109,10 +109,6 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
kernel::select_assign(tensor, dim, indices, value)
}
fn int_cat<const D: usize>(tensors: Vec<IntTensor<Self, D>>, dim: usize) -> IntTensor<Self, D> {
kernel::cat(tensors, dim)
}
fn int_equal<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,

View File

@@ -1,51 +0,0 @@
@group(0)
@binding(0)
var<storage, read> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let dim: u32 = info[0];
let dim_cat = info[4u * dim + 1u];
let dim_cat_index = info[4u * dim + 2u];
var num_elems = 1u;
var index_input = 0u;
var index_output = 0u;
for (var i: u32 = 1u; i <= dim; i++) {
let stride_input = info[i];
let stride_output = info[i + dim];
let shape_input = info[i + 2u * dim];
let shape_output = info[i + 3u * dim];
let num_block_output = id / stride_input % shape_input;
index_input += num_block_output * stride_input;
num_elems *= shape_input;
if i - 1u == dim_cat {
index_output += (num_block_output + dim_cat_index) * stride_output;
} else {
index_output += num_block_output * stride_output;
}
}
if id < num_elems {
output[index_output] = input[index_input];
}
}
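
The deleted shader's loop decomposes a linear element id of one input using the input strides, then re-linearizes it with the output strides, shifting the coordinate only along the cat dimension. An illustrative CPU-side mirror of that arithmetic (function and parameter names are assumptions, not from the source):

// Map one input element's linear id to its offset in the concatenated output.
fn output_index(
    id: u32,
    in_strides: &[u32],
    out_strides: &[u32],
    in_shape: &[u32],
    dim_cat: usize,     // dimension being concatenated
    dim_cat_index: u32, // running offset of this input along dim_cat
) -> u32 {
    let mut index_output = 0;
    for d in 0..in_strides.len() {
        let coord = id / in_strides[d] % in_shape[d];
        if d == dim_cat {
            index_output += (coord + dim_cat_index) * out_strides[d];
        } else {
            index_output += coord * out_strides[d];
        }
    }
    index_output
}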

View File

@@ -1,6 +1,6 @@
use super::{BoolTensor, Device, FloatTensor, IntTensor};
use super::{cat::cat_with_slice_assign, BoolTensor, Device, FloatTensor, IntTensor};
use crate::{
argwhere, backend::Backend, chunk, narrow, tensor::Shape, Bool, Data, ElementConversion,
argwhere, backend::Backend, chunk, narrow, tensor::Shape, Bool, Data, ElementConversion, Tensor,
};
use alloc::vec::Vec;
use burn_common::reader::Reader;
@@ -205,7 +205,16 @@ pub trait BoolTensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the tensors concatenated along the given dimension.
fn bool_cat<const D: usize>(tensors: Vec<BoolTensor<B, D>>, dim: usize) -> BoolTensor<B, D>;
fn bool_cat<const D: usize>(tensors: Vec<BoolTensor<B, D>>, dim: usize) -> BoolTensor<B, D> {
cat_with_slice_assign::<B, D, Bool>(
tensors
.into_iter()
.map(Tensor::<B, D, Bool>::from_primitive)
.collect(),
dim,
)
.into_primitive()
}
/// Equates the two tensors.
///

View File

@@ -1,4 +1,6 @@
use super::cat::cat_with_slice_assign;
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use crate::Tensor;
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};
use crate::{tensor::api::chunk, tensor::api::narrow};
use alloc::vec::Vec;
@@ -299,7 +301,16 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The concatenated tensor.
fn int_cat<const D: usize>(tensors: Vec<IntTensor<B, D>>, dim: usize) -> IntTensor<B, D>;
fn int_cat<const D: usize>(tensors: Vec<IntTensor<B, D>>, dim: usize) -> IntTensor<B, D> {
cat_with_slice_assign::<B, D, Int>(
tensors
.into_iter()
.map(Tensor::<B, D, Int>::from_primitive)
.collect(),
dim,
)
.into_primitive()
}
/// Element-wise equality comparison.
///

View File

@@ -0,0 +1,39 @@
use crate::{backend::Backend, BasicOps, Tensor, TensorKind};
use alloc::vec::Vec;
pub(crate) fn cat_with_slice_assign<B: Backend, const D: usize, K: TensorKind<B> + BasicOps<B>>(
tensors: Vec<Tensor<B, D, K>>,
dim: usize,
) -> Tensor<B, D, K> {
let first_tensor = tensors.first().expect("Tensors should not be empty");
let mut shape = first_tensor.shape();
let device = first_tensor.device();
let output_dim_length: usize = tensors
.iter()
.map(|tensor: &Tensor<B, D, K>| tensor.shape().dims[dim])
.sum();
shape.dims[dim] = output_dim_length;
let mut tensor_output = Tensor::empty(shape.clone(), &device);
let mut i = 0;
let indices_select_all = [0; D].map(|_| {
let start = 0;
let end = shape.dims[i];
i += 1;
start..end
});
let mut output_index = 0;
for tensor in tensors {
let mut indices = indices_select_all.clone();
let tensor_dim_length = tensor.shape().dims[dim];
indices[dim] = output_index..output_index + tensor_dim_length;
output_index += tensor_dim_length;
tensor_output = tensor_output.slice_assign(indices, tensor);
}
tensor_output
}
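
The new helper is the familiar copy-into-preallocated-buffer pattern lifted to N dimensions: allocate the full output once, then slice-assign each input into its window along dim. A plain-Rust one-dimensional analogue for intuition (not from the source):

// 1-D analogue of cat_with_slice_assign.
fn cat_1d(inputs: &[Vec<f32>]) -> Vec<f32> {
    let total: usize = inputs.iter().map(|v| v.len()).sum();
    let mut output = vec![0.0_f32; total];
    let mut offset = 0;
    for input in inputs {
        // The "slice_assign" step: write this input into its output window.
        output[offset..offset + input.len()].copy_from_slice(input);
        offset += input.len();
    }
    output
}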

View File

@@ -1,6 +1,8 @@
/// Module with convolution operations.
pub mod conv;
/// Module with cat operation
pub(crate) mod cat;
/// Module with unfold operations.
pub(crate) mod unfold;

View File

@@ -1,4 +1,6 @@
use super::cat::cat_with_slice_assign;
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
use crate::Tensor;
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Float};
use crate::{tensor::api::chunk, tensor::api::narrow};
use alloc::vec::Vec;
@@ -1074,17 +1076,26 @@ pub trait FloatTensorOps<B: Backend> {
/// A tensor with the same shape as `tensor` with error function values.
fn float_erf<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Catcatenates tensors along a dimension.
/// Concatenates tensors along a dimension.
///
/// # Arguments
///
/// * `tensors` - The tensors to catcatenate.
/// * `dim` - The dimension along which to catcatenate.
/// * `tensors` - The tensors to concatenate.
/// * `dim` - The dimension along which to concatenate.
///
/// # Returns
///
/// A tensor with the catcatenated tensors along `dim`.
fn float_cat<const D: usize>(tensors: Vec<FloatTensor<B, D>>, dim: usize) -> FloatTensor<B, D>;
/// A tensor with the concatenated tensors along `dim`.
fn float_cat<const D: usize>(tensors: Vec<FloatTensor<B, D>>, dim: usize) -> FloatTensor<B, D> {
cat_with_slice_assign::<B, D, Float>(
tensors
.into_iter()
.map(Tensor::<B, D>::from_primitive)
.collect(),
dim,
)
.into_primitive()
}
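
With these trait defaults, a backend only overrides bool_cat/int_cat/float_cat when it has a faster kernel; otherwise Tensor::cat works out of the box. A hedged usage sketch over a generic backend B (import paths assumed):

use burn_tensor::{backend::Backend, Tensor};

// Concatenate a [2, 3] and a [4, 3] tensor along dim 0, giving [6, 3].
fn demo<B: Backend>(device: &B::Device) {
    let a = Tensor::<B, 2>::zeros([2, 3], device);
    let b = Tensor::<B, 2>::ones([4, 3], device);
    let c = Tensor::cat(vec![a, b], 0);
    assert_eq!(c.shape().dims, [6, 3]);
}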
/// Gets the indices of the maximum elements of a tensor along an axis.
///