refactor: wgpu reductions (#471)

Nathaniel Simard 2023-07-06 11:40:37 -04:00 committed by GitHub
parent d78f25f922
commit 04ad14a32a
11 changed files with 162 additions and 105 deletions

View File

@@ -634,7 +634,10 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The mean of all elements in the tensor.
fn int_mean<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1>;
fn int_mean<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1> {
let num_elems = B::int_shape(&tensor).num_elements();
B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
}
/// Computes the mean of all elements in the tensor along a dimension.
///
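Note on the hunk above: `int_mean` gains a default implementation, so a backend only has to provide `int_sum`, `int_shape`, and `int_div_scalar`. A minimal sketch of the arithmetic it performs, using plain integers rather than a real backend (whether the division truncates exactly like this depends on the backend's `int_div_scalar`, so treat the rounding as an assumption):

// Reference arithmetic for the default `int_mean`: sum the elements, then
// divide by the element count. With integer elements the division here
// truncates toward zero (10 / 4 == 2).
fn int_mean_reference(values: &[i64]) -> i64 {
    let sum: i64 = values.iter().sum();
    sum / values.len() as i64
}

fn main() {
    assert_eq!(int_mean_reference(&[1, 2, 3, 4]), 2);
}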

View File

@@ -716,7 +716,10 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// A scalar tensor with the mean of all elements in `tensor`.
fn mean<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
fn mean<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
let num_elems = B::shape(&tensor).num_elements();
B::div_scalar(B::sum(tensor), (num_elems as i64).elem())
}
/// Mean of all elements in a tensor along a dimension.
///

View File

@@ -47,6 +47,12 @@ pub struct WorkGroup {
pub z: u32,
}
impl WorkGroup {
pub fn num_invocations(&self) -> usize {
(self.x * self.y * self.z) as usize
}
}
impl Context {
/// Create a new context where computing tasks will be executed on the given
/// [device](WgpuDevice).
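Note on the hunk above: `num_invocations` reports the total number of workgroups in a dispatch (x * y * z), which the recursive sum below uses to size its intermediate buffers. The `elemwise_workgroup` helper that produces these dispatches is used throughout this commit but not shown in the diff; the sketch below is an assumption about its shape, not its actual definition. The idea is that each workgroup covers `workgroup_size * workgroup_size` invocations and the required workgroup count is split across the x and y dimensions, which is why the kernels now flatten a 2D invocation id.

// Hypothetical sketch of `elemwise_workgroup` (not part of this diff), reusing
// the `WorkGroup` type defined in this file: split the required workgroup count
// over x and y so large dispatches stay within per-dimension limits.
fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup {
    let invocations_per_workgroup = workgroup_size * workgroup_size;
    let num_workgroups = f32::ceil(num_elems as f32 / invocations_per_workgroup as f32);
    let workgroup_x = f32::ceil(f32::sqrt(num_workgroups));
    let workgroup_y = f32::ceil(num_workgroups / workgroup_x);

    WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
}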

View File

@@ -99,6 +99,10 @@ impl<
.register("workgroup_size_x", WORKGROUP_X_SIZE.to_string())
.register("workgroup_size_y", WORKGROUP_Y_SIZE.to_string())
.register("workgroup_size_z", WORKGROUP_Z_SIZE.to_string())
.register(
"workgroup_size",
(WORKGROUP_X_SIZE * WORKGROUP_Y_SIZE * WORKGROUP_Z_SIZE).to_string(),
)
.register("elem", E::type_name())
.register("int", I::type_name())
}
@@ -123,6 +127,10 @@ impl<K: StaticKernel, E: WgpuElement, I: WgpuElement> DynamicKernel
.register("workgroup_size_x", self.workgroup_x_size.to_string())
.register("workgroup_size_y", self.workgroup_y_size.to_string())
.register("workgroup_size_z", self.workgroup_z_size.to_string())
.register(
"workgroup_size",
(self.workgroup_x_size * self.workgroup_y_size * self.workgroup_z_size).to_string(),
)
.register("elem", E::type_name())
.register("int", I::type_name())
}
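Note on the two registrations above: both the static and the dynamic kernel settings now expose a `workgroup_size` template variable equal to the product of the three workgroup dimensions. The `recursive_sum.wgsl` template further down uses it to size its workgroup-shared `data` array, so the array always matches the number of invocations per workgroup. A trivial sketch of the substituted value (the 32 x 32 x 1 sizes are the ones the reduction kernels in this commit pass as type parameters, not the constants' actual definitions):

// The value substituted for {{ workgroup_size }} is simply x * y * z.
// For 32 x 32 x 1 workgroups this renders as `const WORKGROUP_SIZE = 1024u;`.
fn main() {
    let (x, y, z): (u32, u32, u32) = (32, 32, 1);
    let workgroup_size = x * y * z;
    assert_eq!(workgroup_size.to_string(), "1024");
}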

View File

@@ -1,5 +1,5 @@
use super::{build_info, KernelSettings, SourceTemplate, StaticKernel};
use crate::{context::WorkGroup, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
use crate::{element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, tensor::WgpuTensor};
use burn_tensor::Shape;
kernel_wgsl!(RecursiveSumRaw, "../template/reduction/recursive_sum.wgsl");
@@ -13,7 +13,7 @@ pub struct MeanDim;
impl StaticKernel for SumDim {
fn source_template() -> SourceTemplate {
ReductionDimRaw::source_template().register("assign", "output[global_id.x] = sum;")
ReductionDimRaw::source_template().register("assign", "output[id] = sum;")
}
}
@@ -25,7 +25,7 @@ impl StaticKernel for MeanDim {
return sum / {{ elem }}(dim);
}",
)
.register("assign", "output[global_id.x] = mean_dim(sum, shape_dim);")
.register("assign", "output[id] = mean_dim(sum, shape_dim);")
}
}
@@ -45,50 +45,70 @@ impl StaticKernel for ArgsMin {
}
}
pub fn reduction_sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
const WORKGROUP: usize = 256;
/// Sum all elements in the input buffer.
pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
const WORKGROUP: usize = 32;
let mut input_buffer = input.buffer;
let mut num_invocations =
f32::ceil(input.shape.num_elements() as f32 / WORKGROUP as f32) as usize;
let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP);
let kernel = input
.context
.compile_static::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, 1, 1>>();
.compile_static::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();
loop {
let num_invocations = workgroup.num_invocations();
let buffer = input
.context
.create_buffer(core::mem::size_of::<E>() * num_invocations);
let workgroup = WorkGroup::new(num_invocations as u32, 1, 1);
input
.context
.execute(workgroup, kernel.clone(), &[&input_buffer, &buffer]);
.execute(workgroup.clone(), kernel.clone(), &[&input_buffer, &buffer]);
if num_invocations == 1 {
if num_invocations <= 1 {
return WgpuTensor::new(input.context, Shape::new([1]), buffer);
}
input_buffer = buffer;
num_invocations = f32::ceil(num_invocations as f32 / WORKGROUP as f32) as usize;
workgroup = elemwise_workgroup(num_invocations, WORKGROUP);
}
}
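Note on `sum` above: each pass launches one workgroup of 32 x 32 invocations per 1024 input elements, and every workgroup collapses its slice into a single partial sum, so the element count shrinks by roughly a factor of 1024 per iteration until one value remains. A rough sketch of the pass count (it ignores any padding when the workgroup count does not split evenly across x and y, which is an assumption about `elemwise_workgroup`):

// Approximate pass schedule for the recursive sum: every pass reduces the
// remaining element count by a factor of WORKGROUP * WORKGROUP (= 1024).
fn main() {
    const WORKGROUP: usize = 32;
    let per_pass = WORKGROUP * WORKGROUP;

    let mut remaining = 6 * 256; // the shape used in the test below
    let mut passes = 0;
    while remaining > 1 {
        remaining = (remaining + per_pass - 1) / per_pass; // ceiling division
        passes += 1;
    }
    assert_eq!(passes, 2); // 1536 elements -> 2 partial sums -> 1 result
}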
pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
/// Execute the sum dim kernel.
pub fn sum_dim<E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
reduction_dim::<SumDim, E, D>(input, dim)
}
/// Execute the mean dim kernel.
pub fn mean_dim<E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
reduction_dim::<MeanDim, E, D>(input, dim)
}
fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
const WORKGROUP: usize = 32;
let mut shape_out = input.shape.clone();
shape_out.dims[dim] = 1;
let num_elems = shape_out.num_elements();
let buffer = input
.context
.create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
let kernel = input
.context
.compile_static::<KernelSettings<K, E, i32, 256, 1, 1>>();
.compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
let mut info = build_info(&[&input, &output]);
info.push(dim as u32);
let info_buffers = input
@@ -96,11 +116,7 @@ pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
.create_buffer_with_data(bytemuck::cast_slice(&info));
input.context.execute(
WorkGroup::new(
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
1,
1,
),
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&input.buffer, &output.buffer, &info_buffers],
);
@@ -108,20 +124,39 @@ pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
output
}
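Note on `reduction_dim` above: the kernel is dispatched with one invocation per output element, and each invocation walks the reduced dimension of the input. Semantically the result matches a plain per-slice reduction; a minimal CPU reference for the 2D, dim = 1 case (a sketch of the semantics only, not of the WGSL):

// CPU reference for sum over dim = 1 of a row-major [rows, cols] tensor:
// every output element is the sum of one row.
fn sum_dim_1(input: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    (0..rows)
        .map(|r| input[r * cols..(r + 1) * cols].iter().sum())
        .collect()
}

fn main() {
    let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // shape [2, 3]
    assert_eq!(sum_dim_1(&x, 2, 3), vec![6.0, 15.0]);
}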
pub fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
/// Execute the argmax kernel.
pub fn argmax<E: WgpuElement, I: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<I, D> {
reduction_args_dim::<ArgsMax, E, I, D>(input, dim)
}
/// Execute the argmin kernel.
pub fn argmin<E: WgpuElement, I: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<I, D> {
reduction_args_dim::<ArgsMin, E, I, D>(input, dim)
}
fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<I, D> {
const WORKGROUP: usize = 32;
let mut shape_out = input.shape.clone();
shape_out.dims[dim] = 1;
let num_elems = shape_out.num_elements();
let buffer = input
.context
.create_buffer(shape_out.num_elements() * core::mem::size_of::<I>());
.create_buffer(num_elems * core::mem::size_of::<I>());
let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
let kernel = input
.context
.compile_static::<KernelSettings<K, E, I, 256, 1, 1>>();
.compile_static::<KernelSettings<K, E, I, WORKGROUP, WORKGROUP, 1>>();
let mut info = build_info(&[&input, &output]);
info.push(dim as u32);
let info_buffers = input
@@ -129,14 +164,53 @@ pub fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const
.create_buffer_with_data(bytemuck::cast_slice(&info));
input.context.execute(
WorkGroup::new(
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
1,
1,
),
elemwise_workgroup(num_elems, WORKGROUP),
kernel,
&[&input.buffer, &output.buffer, &info_buffers],
);
WgpuTensor::new(output.context, output.shape, output.buffer)
}
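Note on `reduction_args_dim` above: it follows the same dispatch pattern as `reduction_dim`, but the output element type `I` holds the index of the extremum instead of a reduced value. A minimal CPU reference for argmax over dim = 1 of a 2D tensor (semantics only; returning the first index on ties is an assumption, since the kernel's tie-breaking is not shown in this diff):

// CPU reference for argmax over dim = 1 of a row-major [rows, cols] tensor:
// every output element is the column index of the row's maximum.
fn argmax_dim_1(input: &[f32], rows: usize, cols: usize) -> Vec<u32> {
    (0..rows)
        .map(|r| {
            let row = &input[r * cols..(r + 1) * cols];
            let mut best = 0usize;
            for i in 1..cols {
                if row[i] > row[best] {
                    best = i;
                }
            }
            best as u32
        })
        .collect()
}

fn main() {
    let x = vec![1.0, 7.0, 3.0, 9.0, 2.0, 2.0]; // shape [2, 3]
    assert_eq!(argmax_dim_1(&x, 2, 3), vec![1, 0]);
}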
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ReferenceBackend, TestBackend};
use burn_tensor::{Distribution, Int, Tensor};
#[test]
fn reduction_sum_should_work_with_multiple_invocations() {
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let val = Tensor::<TestBackend, 1>::from_primitive(sum(tensor.into_primitive()));
let val_ref = tensor_ref.sum();
val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
}
#[test]
fn reduction_sum_dim_should_work_with_multiple_invocations() {
let tensor = Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let val = Tensor::<TestBackend, 2>::from_primitive(reduction_dim::<SumDim, f32, 2>(
tensor.into_primitive(),
1,
));
let val_ref = tensor_ref.sum_dim(1);
val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
}
#[test]
fn reduction_args_dim_should_work_with_multiple_invocations() {
let tensor = Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default);
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
let val = Tensor::<TestBackend, 2, Int>::from_primitive(argmax(tensor.into_primitive(), 1));
let val_ref = tensor_ref.argmax(1);
assert_eq!(val_ref.into_data().convert(), val.into_data());
}
}

View File

@@ -295,19 +295,15 @@ where
}
fn sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
NumericOps::<G>::sum(tensor)
kernel::sum(tensor)
}
fn sum_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
NumericOps::<G>::sum_dim(tensor, dim)
}
fn mean<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
NumericOps::<G>::mean(tensor)
kernel::sum_dim(tensor, dim)
}
fn mean_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
NumericOps::<G>::mean_dim(tensor, dim)
kernel::mean_dim(tensor, dim)
}
fn to_full_precision<const D: usize>(
@@ -427,10 +423,10 @@ where
}
fn argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
NumericOps::<G>::argmax(tensor, dim)
kernel::argmax(tensor, dim)
}
fn argmin<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
NumericOps::<G>::argmin(tensor, dim)
kernel::argmin(tensor, dim)
}
}
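Note on the hunks above: the backend-specific `mean` override is removed, so on the wgpu backend `Tensor::mean` now goes through the new default `TensorOps::mean` (sum followed by a scalar division), while `sum`, `sum_dim`, `mean_dim`, `argmax`, and `argmin` dispatch straight to the new kernels. A small backend-generic usage sketch (nothing wgpu-specific is assumed beyond the tensor APIs already used in the tests above):

// User-facing calls are unchanged; only the code path behind them moved.
use burn_tensor::{backend::Backend, Distribution, Tensor};

fn reductions<B: Backend>() {
    let t = Tensor::<B, 2>::random([4, 8], Distribution::Default);

    let _mean = t.clone().mean();           // default impl: sum / element count
    let _mean_dim = t.clone().mean_dim(1);  // dedicated MeanDim kernel on wgpu
    let _sum = t.sum();                     // recursive sum kernel on wgpu
}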

View File

@@ -1,3 +1,4 @@
use super::{numeric::NumericOps, BaseOps, BoolTensor, Device, IntElem, IntTensor};
use crate::{
element::{FloatElement, IntElement},
kernel, GraphicsApi, WgpuBackend,
@@ -5,8 +6,6 @@ use crate::{
use burn_tensor::{ops::IntTensorOps, Data, Shape};
use std::ops::Range;
use super::{numeric::NumericOps, BaseOps, BoolTensor, Device, IntElem, IntTensor};
impl<G, F, I> IntTensorOps<WgpuBackend<G, F, I>> for WgpuBackend<G, F, I>
where
G: GraphicsApi + 'static,
@@ -254,25 +253,22 @@ where
}
fn int_sum<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
NumericOps::<G>::sum(tensor)
kernel::sum(tensor)
}
fn int_sum_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
NumericOps::<G>::sum_dim(tensor, dim)
}
fn int_mean<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
NumericOps::<G>::mean(tensor)
kernel::sum_dim(tensor, dim)
}
fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
NumericOps::<G>::mean_dim(tensor, dim)
kernel::mean_dim(tensor, dim)
}
fn int_argmax<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
NumericOps::<G>::argmax(tensor, dim)
kernel::argmax(tensor, dim)
}
fn int_argmin<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
NumericOps::<G>::argmin(tensor, dim)
kernel::argmin(tensor, dim)
}
}

View File

@@ -1,7 +1,6 @@
use crate::kernel::{
binary_elemwise_default, binary_elemwise_inplace_default, reduction_args_dim, reduction_dim,
reduction_sum, unary_scalar_default, unary_scalar_inplace_default, ArgsMax, ArgsMin, MeanDim,
SumDim,
binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default,
unary_scalar_inplace_default,
};
use crate::pool::get_context;
use crate::{
@@ -158,45 +157,4 @@ impl<G: GraphicsApi> NumericOps<G> {
unary_scalar_default::<DivScalar, E, D>(lhs, rhs)
}
pub fn sum<E: WgpuElement + Element, const D: usize>(
tensor: WgpuTensor<E, D>,
) -> WgpuTensor<E, 1> {
reduction_sum(tensor)
}
pub fn sum_dim<E: WgpuElement + Element, const D: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
reduction_dim::<SumDim, E, D>(tensor, dim)
}
pub fn mean<E: WgpuElement + Element, const D: usize>(
tensor: WgpuTensor<E, D>,
) -> WgpuTensor<E, 1> {
let num_elems = tensor.shape.num_elements();
Self::div_scalar(Self::sum(tensor), (num_elems as f32).elem())
}
pub fn mean_dim<E: WgpuElement + Element, const D: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
reduction_dim::<MeanDim, E, D>(tensor, dim)
}
pub fn argmax<E: WgpuElement + Element, I: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<I, D> {
reduction_args_dim::<ArgsMax, E, I, D>(tensor, dim)
}
pub fn argmin<E: WgpuElement + Element, I: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<I, D> {
reduction_args_dim::<ArgsMin, E, I, D>(tensor, dim)
}
}

View File

@@ -10,11 +10,15 @@ var<storage, read_write> output: array<{{ int }}>;
@binding(2)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let dim: u32 = info[0];
let dim_reduce = info[4u * dim + 1u];
var index_offset: u32 = 0u;
@@ -26,7 +30,7 @@ fn main(
let stride_output = info[i + dim];
let shape_output = info[i + 3u * dim];
let num_block = global_id.x / stride_output % shape_output;
let num_block = id / stride_output % shape_output;
if i - 1u != dim_reduce {
index_offset += num_block * stride_input;
@@ -52,5 +56,5 @@ fn main(
}
}
output[global_id.x] = index;
output[id] = index;
}
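Note on the indexing above: the kernel reads tensor metadata from the `info` buffer that `reduction_args_dim` builds with `build_info(&[&input, &output])` plus a trailing `dim`. The layout below is inferred from the indices the kernel uses (`info[0]`, `info[i + dim]`, `info[i + 3u * dim]`, `info[4u * dim + 1u]`); treat it as an assumption, not the documented contract of `build_info`:

// Inferred layout of the `info` buffer for a rank-D reduction (deduced from the
// kernel's indexing, not code from this commit):
//   info[0]             = D (the rank)
//   info[1      ..= D]  = input strides
//   info[D + 1  ..= 2D] = output strides
//   info[2D + 1 ..= 3D] = input shape
//   info[3D + 1 ..= 4D] = output shape
//   info[4D + 1]        = index of the reduced dimension (pushed last)
fn build_info_sketch(
    strides_in: &[u32],
    strides_out: &[u32],
    shape_in: &[u32],
    shape_out: &[u32],
    dim: u32,
) -> Vec<u32> {
    let d = strides_in.len() as u32;
    let mut info = vec![d];
    info.extend_from_slice(strides_in);
    info.extend_from_slice(strides_out);
    info.extend_from_slice(shape_in);
    info.extend_from_slice(shape_out);
    info.push(dim);
    info
}

fn main() {
    // A contiguous [2, 3] input reduced over dim = 1 into a [2, 1] output:
    let info = build_info_sketch(&[3, 1], &[1, 1], &[2, 3], &[2, 1], 1);
    assert_eq!(info, vec![2, 3, 1, 1, 1, 2, 3, 2, 1, 1]);
}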

View File

@@ -1,4 +1,5 @@
const BLOCK_SIZE = {{ workgroup_size_x }}u;
const WORKGROUP_SIZE = {{ workgroup_size }}u;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@group(0)
@binding(0)
@@ -8,26 +9,30 @@ var<storage, read> input: array<{{ elem }}>;
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
var<workgroup> data: array<{{ elem }}, BLOCK_SIZE>;
var<workgroup> data: array<{{ elem }}, WORKGROUP_SIZE>;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
data[local_id.x] = input[global_id.x];
let id_global = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let id_local = local_id.y * WORKGROUP_SIZE_X + local_id.x;
data[id_local] = input[id_global];
workgroupBarrier();
if local_id.x == 0u {
if id_local == 0u {
var sum = {{ elem }}(0);
for (var i: u32 = 0u; i < BLOCK_SIZE; i++) {
for (var i: u32 = 0u; i < WORKGROUP_SIZE; i++) {
sum += data[i];
}
output[workgroup_id.x] = sum;
let id_output = workgroup_id.y * num_workgroups.x + workgroup_id.x;
output[id_output] = sum;
}
}
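Note on the hunk above: the recursive sum kernel now runs on a 2D grid, so the global and local invocation ids are flattened into linear indices before being used as buffer offsets, and each workgroup writes its partial sum at its own flattened workgroup id. A CPU mirror of the three formulas with example numbers (the concrete values are illustrative only):

// Mirrors the index arithmetic of the kernel above for WORKGROUP_SIZE_X = 32
// and a dispatch of num_workgroups = (4, 2, 1).
fn main() {
    const WORKGROUP_SIZE_X: u32 = 32;
    let num_workgroups_x: u32 = 4;

    // id_global = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x
    let (global_x, global_y) = (5u32, 3u32);
    let id_global = global_y * (num_workgroups_x * WORKGROUP_SIZE_X) + global_x;
    assert_eq!(id_global, 389); // 3 * 128 + 5

    // id_local = local_id.y * WORKGROUP_SIZE_X + local_id.x
    let (local_x, local_y) = (5u32, 3u32);
    let id_local = local_y * WORKGROUP_SIZE_X + local_x;
    assert_eq!(id_local, 101); // 3 * 32 + 5

    // id_output = workgroup_id.y * num_workgroups.x + workgroup_id.x
    let (wg_x, wg_y) = (2u32, 1u32);
    let id_output = wg_y * num_workgroups_x + wg_x;
    assert_eq!(id_output, 6); // 1 * 4 + 2
}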

View File

@@ -10,11 +10,15 @@ var<storage, read_write> output: array<{{ elem }}>;
@binding(2)
var<storage, read> info: array<u32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let dim: u32 = info[0];
let dim_reduce = info[4u * dim + 1u];
var index_offset: u32 = 0u;
@@ -26,7 +30,7 @@ fn main(
let stride_output = info[i + dim];
let shape_output = info[i + 3u * dim];
let num_block = global_id.x / stride_output % shape_output;
let num_block = id / stride_output % shape_output;
if i - 1u != dim_reduce {
index_offset += num_block * stride_input;