From 04ad14a32a7bbafbc6befc60415445d1514ec4af Mon Sep 17 00:00:00 2001
From: Nathaniel Simard
Date: Thu, 6 Jul 2023 11:40:37 -0400
Subject: [PATCH] refactor: wgpu reductions (#471)

---
 burn-tensor/src/tensor/ops/int_tensor.rs      |   5 +-
 burn-tensor/src/tensor/ops/tensor.rs          |   5 +-
 burn-wgpu/src/context/base.rs                 |   6 +
 burn-wgpu/src/kernel/base.rs                  |   8 ++
 burn-wgpu/src/kernel/reduction.rs             | 130 ++++++++++++++----
 burn-wgpu/src/ops/float_ops.rs                |  14 +-
 burn-wgpu/src/ops/int_ops.rs                  |  16 +--
 burn-wgpu/src/ops/numeric.rs                  |  46 +------
 burn-wgpu/src/template/reduction/args.wgsl    |  10 +-
 .../src/template/reduction/recursive_sum.wgsl |  19 ++-
 .../src/template/reduction/reduce_dim.wgsl    |   8 +-
 11 files changed, 162 insertions(+), 105 deletions(-)

diff --git a/burn-tensor/src/tensor/ops/int_tensor.rs b/burn-tensor/src/tensor/ops/int_tensor.rs
index 4db6488d8..b1580c4d9 100644
--- a/burn-tensor/src/tensor/ops/int_tensor.rs
+++ b/burn-tensor/src/tensor/ops/int_tensor.rs
@@ -634,7 +634,10 @@ pub trait IntTensorOps<B: Backend> {
     /// # Returns
     ///
     /// The mean of all elements in the tensor.
-    fn int_mean<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1>;
+    fn int_mean<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1> {
+        let num_elems = B::int_shape(&tensor).num_elements();
+        B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
+    }

     /// Computes the mean of all elements in the tensor along a dimension.
     ///
diff --git a/burn-tensor/src/tensor/ops/tensor.rs b/burn-tensor/src/tensor/ops/tensor.rs
index 18b13d0e4..2321cc242 100644
--- a/burn-tensor/src/tensor/ops/tensor.rs
+++ b/burn-tensor/src/tensor/ops/tensor.rs
@@ -716,7 +716,10 @@ pub trait TensorOps<B: Backend> {
     /// # Returns
     ///
     /// A scalar tensor with the mean of all elements in `tensor`.
-    fn mean<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
+    fn mean<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
+        let num_elems = B::shape(&tensor).num_elements();
+        B::div_scalar(B::sum(tensor), (num_elems as i64).elem())
+    }

     /// Mean of all elements in a tensor along a dimension.
     ///
diff --git a/burn-wgpu/src/context/base.rs b/burn-wgpu/src/context/base.rs
index 18a731974..f64ba6b14 100644
--- a/burn-wgpu/src/context/base.rs
+++ b/burn-wgpu/src/context/base.rs
@@ -47,6 +47,12 @@ pub struct WorkGroup {
     pub z: u32,
 }

+impl WorkGroup {
+    pub fn num_invocations(&self) -> usize {
+        (self.x * self.y * self.z) as usize
+    }
+}
+
 impl Context {
     /// Create a new context where computing tasks will be executed on the given
     /// [device](WgpuDevice).
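Both trait methods now ship a default implementation that expresses mean as sum followed by a scalar division, so a backend only has to provide efficient sum and div_scalar kernels to get mean for free. A minimal sketch of that pattern, using a hypothetical Reduce trait over plain slices instead of burn's backend-associated tensor primitives:

    // Sketch of the "mean = sum / count" trait default; `Reduce` and `Naive`
    // are illustrative stand-ins, not burn types.
    trait Reduce {
        fn sum(values: &[f32]) -> f32;

        // Provided method: implementors get `mean` for free once `sum` exists,
        // mirroring the new defaults on `TensorOps` and `IntTensorOps`.
        fn mean(values: &[f32]) -> f32 {
            Self::sum(values) / values.len() as f32
        }
    }

    struct Naive;

    impl Reduce for Naive {
        fn sum(values: &[f32]) -> f32 {
            values.iter().copied().sum()
        }
    }

    fn main() {
        assert_eq!(Naive::mean(&[1.0, 2.0, 3.0, 4.0]), 2.5);
    }

A backend can still override mean when a fused kernel would be faster; the default only guarantees a correct fallback.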
diff --git a/burn-wgpu/src/kernel/base.rs b/burn-wgpu/src/kernel/base.rs
index be54e3e95..c8c68e15b 100644
--- a/burn-wgpu/src/kernel/base.rs
+++ b/burn-wgpu/src/kernel/base.rs
@@ -99,6 +99,10 @@ impl<
             .register("workgroup_size_x", WORKGROUP_X_SIZE.to_string())
             .register("workgroup_size_y", WORKGROUP_Y_SIZE.to_string())
             .register("workgroup_size_z", WORKGROUP_Z_SIZE.to_string())
+            .register(
+                "workgroup_size",
+                (WORKGROUP_X_SIZE * WORKGROUP_Y_SIZE * WORKGROUP_Z_SIZE).to_string(),
+            )
             .register("elem", E::type_name())
             .register("int", I::type_name())
     }
@@ -123,6 +127,10 @@ impl<K: DynamicKernel, E: WgpuElement, I: WgpuElement> DynamicKernelSettings<K, E, I> {
             .register("workgroup_size_x", self.workgroup_x_size.to_string())
             .register("workgroup_size_y", self.workgroup_y_size.to_string())
             .register("workgroup_size_z", self.workgroup_z_size.to_string())
+            .register(
+                "workgroup_size",
+                (self.workgroup_x_size * self.workgroup_y_size * self.workgroup_z_size).to_string(),
+            )
             .register("elem", E::type_name())
             .register("int", I::type_name())
     }
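Alongside the per-axis placeholders, templates can now read {{ workgroup_size }}, the total number of invocations per workgroup (x * y * z); recursive_sum.wgsl uses it further down to size its shared-memory array, and it matches WorkGroup::num_invocations on the host side. A rough sketch of what registration amounts to, assuming register boils down to a literal string substitution (the real SourceTemplate may do more):

    // Hypothetical stand-in for `SourceTemplate::register`: replace the
    // "{{ key }}" placeholder in the WGSL source with a concrete value.
    fn register(source: &str, key: &str, value: &str) -> String {
        source.replace(&format!("{{{{ {key} }}}}"), value)
    }

    fn main() {
        let (x, y, z) = (32u32, 32, 1);
        let wgsl = "var<workgroup> data: array<f32, {{ workgroup_size }}>;";

        // Total invocations per workgroup, matching WorkGroup::num_invocations.
        let rendered = register(wgsl, "workgroup_size", &(x * y * z).to_string());
        assert_eq!(rendered, "var<workgroup> data: array<f32, 1024>;");
    }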
diff --git a/burn-wgpu/src/kernel/reduction.rs b/burn-wgpu/src/kernel/reduction.rs
index dbf58c469..d83d287f8 100644
--- a/burn-wgpu/src/kernel/reduction.rs
+++ b/burn-wgpu/src/kernel/reduction.rs
@@ -1,5 +1,5 @@
 use super::{build_info, KernelSettings, SourceTemplate, StaticKernel};
-use crate::{context::WorkGroup, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
+use crate::{element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, tensor::WgpuTensor};
 use burn_tensor::Shape;

 kernel_wgsl!(RecursiveSumRaw, "../template/reduction/recursive_sum.wgsl");
@@ -13,7 +13,7 @@ pub struct MeanDim;

 impl StaticKernel for SumDim {
     fn source_template() -> SourceTemplate {
-        ReductionDimRaw::source_template().register("assign", "output[global_id.x] = sum;")
+        ReductionDimRaw::source_template().register("assign", "output[id] = sum;")
     }
 }

@@ -25,7 +25,7 @@ impl StaticKernel for MeanDim {
                 return sum / {{ elem }}(dim);
             }",
         )
-        .register("assign", "output[global_id.x] = mean_dim(sum, shape_dim);")
+        .register("assign", "output[id] = mean_dim(sum, shape_dim);")
     }
 }

@@ -45,50 +45,70 @@ impl StaticKernel for ArgsMin {
     }
 }

-pub fn reduction_sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
-    const WORKGROUP: usize = 256;
+/// Sum all elements in the input buffer.
+pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
+    const WORKGROUP: usize = 32;

     let mut input_buffer = input.buffer;
-    let mut num_invocations =
-        f32::ceil(input.shape.num_elements() as f32 / WORKGROUP as f32) as usize;
+    let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP);

     let kernel = input
         .context
-        .compile_static::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, 1, 1>>();
+        .compile_static::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();

     loop {
+        let num_invocations = workgroup.num_invocations();
         let buffer = input
             .context
             .create_buffer(core::mem::size_of::<E>() * num_invocations);

-        let workgroup = WorkGroup::new(num_invocations as u32, 1, 1);
         input
             .context
-            .execute(workgroup, kernel.clone(), &[&input_buffer, &buffer]);
+            .execute(workgroup.clone(), kernel.clone(), &[&input_buffer, &buffer]);

-        if num_invocations == 1 {
+        if num_invocations <= 1 {
             return WgpuTensor::new(input.context, Shape::new([1]), buffer);
         }

         input_buffer = buffer;
-        num_invocations = f32::ceil(num_invocations as f32 / WORKGROUP as f32) as usize;
+        workgroup = elemwise_workgroup(num_invocations, WORKGROUP);
     }
 }

-pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
+/// Execute the sum dim kernel.
+pub fn sum_dim<E: WgpuElement, const D: usize>(
     input: WgpuTensor<E, D>,
     dim: usize,
 ) -> WgpuTensor<E, D> {
+    reduction_dim::<SumDim, E, D>(input, dim)
+}
+
+/// Execute the mean dim kernel.
+pub fn mean_dim<E: WgpuElement, const D: usize>(
+    input: WgpuTensor<E, D>,
+    dim: usize,
+) -> WgpuTensor<E, D> {
+    reduction_dim::<MeanDim, E, D>(input, dim)
+}
+
+fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
+    input: WgpuTensor<E, D>,
+    dim: usize,
+) -> WgpuTensor<E, D> {
+    const WORKGROUP: usize = 32;
+
     let mut shape_out = input.shape.clone();
     shape_out.dims[dim] = 1;
+    let num_elems = shape_out.num_elements();
     let buffer = input
         .context
-        .create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
+        .create_buffer(num_elems * core::mem::size_of::<E>());
     let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);

     let kernel = input
         .context
-        .compile_static::<KernelSettings<K, E, i32, 256, 1, 1>>();
+        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();
+
     let mut info = build_info(&[&input, &output]);
     info.push(dim as u32);
     let info_buffers = input
@@ -96,11 +116,7 @@ pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
         .create_buffer_with_data(bytemuck::cast_slice(&info));

     input.context.execute(
-        WorkGroup::new(
-            f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
-            1,
-            1,
-        ),
+        elemwise_workgroup(num_elems, WORKGROUP),
         kernel,
         &[&input.buffer, &output.buffer, &info_buffers],
     );
@@ -108,20 +124,39 @@ pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
     output
 }

-pub fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
+/// Execute the argmax kernel.
+pub fn argmax<E: WgpuElement, I: WgpuElement, const D: usize>(
     input: WgpuTensor<E, D>,
     dim: usize,
 ) -> WgpuTensor<I, D> {
+    reduction_args_dim::<ArgsMax, E, I, D>(input, dim)
+}
+
+/// Execute the argmin kernel.
+pub fn argmin<E: WgpuElement, I: WgpuElement, const D: usize>(
+    input: WgpuTensor<E, D>,
+    dim: usize,
+) -> WgpuTensor<I, D> {
+    reduction_args_dim::<ArgsMin, E, I, D>(input, dim)
+}
+
+fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
+    input: WgpuTensor<E, D>,
+    dim: usize,
+) -> WgpuTensor<I, D> {
+    const WORKGROUP: usize = 32;
+
     let mut shape_out = input.shape.clone();
     shape_out.dims[dim] = 1;
+    let num_elems = shape_out.num_elements();
     let buffer = input
         .context
-        .create_buffer(shape_out.num_elements() * core::mem::size_of::<I>());
+        .create_buffer(num_elems * core::mem::size_of::<I>());
     let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);

     let kernel = input
         .context
-        .compile_static::<KernelSettings<K, E, I, 256, 1, 1>>();
+        .compile_static::<KernelSettings<K, E, I, WORKGROUP, WORKGROUP, 1>>();
     let mut info = build_info(&[&input, &output]);
     info.push(dim as u32);
     let info_buffers = input
@@ -129,14 +164,53 @@ pub fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
         .create_buffer_with_data(bytemuck::cast_slice(&info));

     input.context.execute(
-        WorkGroup::new(
-            f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
-            1,
-            1,
-        ),
+        elemwise_workgroup(num_elems, WORKGROUP),
         kernel,
         &[&input.buffer, &output.buffer, &info_buffers],
     );

     WgpuTensor::new(output.context, output.shape, output.buffer)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::tests::{ReferenceBackend, TestBackend};
+    use burn_tensor::{Distribution, Int, Tensor};
+
+    #[test]
+    fn reduction_sum_should_work_with_multiple_invocations() {
+        let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
+        let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
+
+        let val = Tensor::<TestBackend, 1>::from_primitive(sum(tensor.into_primitive()));
+        let val_ref = tensor_ref.sum();
+
+        val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
+    }
+
+    #[test]
+    fn reduction_sum_dim_should_work_with_multiple_invocations() {
+        let tensor = Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default);
+        let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
+
+        let val = Tensor::<TestBackend, 2>::from_primitive(reduction_dim::<SumDim, f32, 2>(
+            tensor.into_primitive(),
+            1,
+        ));
+        let val_ref = tensor_ref.sum_dim(1);
+
+        val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
+    }
+
+    #[test]
+    fn reduction_args_dim_should_work_with_multiple_invocations() {
+        let tensor = Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default);
+        let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
+
+        let val =
+            Tensor::<TestBackend, 2, Int>::from_primitive(argmax(tensor.into_primitive(), 1));
+        let val_ref = tensor_ref.argmax(1);
+
+        assert_eq!(val_ref.into_data().convert(), val.into_data());
+    }
+}
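With WORKGROUP = 32 the recursive kernel is compiled for a 32x32 workgroup, so each dispatch collapses up to 1024 elements per workgroup into a single partial sum, and the host loop keeps re-dispatching over the partials until one value remains. A CPU analogue of that loop (chunk plays the role of the per-workgroup element count; no burn APIs involved):

    // CPU analogue of the recursive sum: repeatedly collapse chunks of
    // `chunk` elements into partial sums until a single value remains.
    fn recursive_sum(mut values: Vec<f32>, chunk: usize) -> f32 {
        loop {
            let partials: Vec<f32> = values.chunks(chunk).map(|c| c.iter().sum()).collect();
            if partials.len() <= 1 {
                return partials.first().copied().unwrap_or(0.0);
            }
            values = partials; // the next "dispatch" reduces the partial sums
        }
    }

    fn main() {
        let values: Vec<f32> = (1..=2048).map(|v| v as f32).collect();
        // 2048 elements -> 2 partials -> 1 result: two passes, like two dispatches.
        assert_eq!(recursive_sum(values, 1024), (2048 * 2049 / 2) as f32);
    }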
diff --git a/burn-wgpu/src/ops/float_ops.rs b/burn-wgpu/src/ops/float_ops.rs
index d01689545..5c2da7b5e 100644
--- a/burn-wgpu/src/ops/float_ops.rs
+++ b/burn-wgpu/src/ops/float_ops.rs
@@ -295,19 +295,15 @@ where
     }

     fn sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
-        NumericOps::<G>::sum(tensor)
+        kernel::sum(tensor)
     }

     fn sum_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
-        NumericOps::<G>::sum_dim(tensor, dim)
-    }
-
-    fn mean<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
-        NumericOps::<G>::mean(tensor)
+        kernel::sum_dim(tensor, dim)
     }

     fn mean_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
-        NumericOps::<G>::mean_dim(tensor, dim)
+        kernel::mean_dim(tensor, dim)
     }

     fn to_full_precision(
@@ -427,10 +423,10 @@ where
     }

     fn argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::argmax(tensor, dim)
+        kernel::argmax(tensor, dim)
     }

     fn argmin<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::argmin(tensor, dim)
+        kernel::argmin(tensor, dim)
     }
 }
diff --git a/burn-wgpu/src/ops/int_ops.rs b/burn-wgpu/src/ops/int_ops.rs
index cbad77b49..9e1647a64 100644
--- a/burn-wgpu/src/ops/int_ops.rs
+++ b/burn-wgpu/src/ops/int_ops.rs
@@ -1,3 +1,4 @@
+use super::{numeric::NumericOps, BaseOps, BoolTensor, Device, IntElem, IntTensor};
 use crate::{
     element::{FloatElement, IntElement},
     kernel, GraphicsApi, WgpuBackend,
@@ -5,8 +6,6 @@ use crate::{
 use burn_tensor::{ops::IntTensorOps, Data, Shape};
 use std::ops::Range;

-use super::{numeric::NumericOps, BaseOps, BoolTensor, Device, IntElem, IntTensor};
-
 impl<G, F, I> IntTensorOps<WgpuBackend<G, F, I>> for WgpuBackend<G, F, I>
 where
     G: GraphicsApi + 'static,
@@ -254,25 +253,22 @@ where
     }

     fn int_sum<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
-        NumericOps::<G>::sum(tensor)
+        kernel::sum(tensor)
     }

     fn int_sum_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::sum_dim(tensor, dim)
-    }
-
-    fn int_mean<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
-        NumericOps::<G>::mean(tensor)
+        kernel::sum_dim(tensor, dim)
     }

     fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::mean_dim(tensor, dim)
+        kernel::mean_dim(tensor, dim)
     }

     fn int_argmax<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::argmax(tensor, dim)
+        kernel::argmax(tensor, dim)
     }

     fn int_argmin<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::argmin(tensor, dim)
+        kernel::argmin(tensor, dim)
     }
 }
diff --git a/burn-wgpu/src/ops/numeric.rs b/burn-wgpu/src/ops/numeric.rs
index 6d43f449b..87f32c21c 100644
--- a/burn-wgpu/src/ops/numeric.rs
+++ b/burn-wgpu/src/ops/numeric.rs
@@ -1,7 +1,6 @@
 use crate::kernel::{
-    binary_elemwise_default, binary_elemwise_inplace_default, reduction_args_dim, reduction_dim,
-    reduction_sum, unary_scalar_default, unary_scalar_inplace_default, ArgsMax, ArgsMin, MeanDim,
-    SumDim,
+    binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default,
+    unary_scalar_inplace_default,
 };
 use crate::pool::get_context;
 use crate::{
@@ -158,45 +157,4 @@ impl<G: GraphicsApi> NumericOps<G> {
         unary_scalar_default::<Div, E, D>(lhs, rhs)
     }
-
-    pub fn sum<E: WgpuElement, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-    ) -> WgpuTensor<E, 1> {
-        reduction_sum(tensor)
-    }
-
-    pub fn sum_dim<E: WgpuElement, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-        dim: usize,
-    ) -> WgpuTensor<E, D> {
-        reduction_dim::<SumDim, E, D>(tensor, dim)
-    }
-
-    pub fn mean<E: WgpuElement, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-    ) -> WgpuTensor<E, 1> {
-        let num_elems = tensor.shape.num_elements();
-        Self::div_scalar(Self::sum(tensor), (num_elems as f32).elem())
-    }
-
-    pub fn mean_dim<E: WgpuElement, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-        dim: usize,
-    ) -> WgpuTensor<E, D> {
-        reduction_dim::<MeanDim, E, D>(tensor, dim)
-    }
-
-    pub fn argmax<E: WgpuElement, I: WgpuElement, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-        dim: usize,
-    ) -> WgpuTensor<I, D> {
-        reduction_args_dim::<ArgsMax, E, I, D>(tensor, dim)
-    }
-
-    pub fn argmin<E: WgpuElement, I: WgpuElement, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-        dim: usize,
-    ) -> WgpuTensor<I, D> {
-        reduction_args_dim::<ArgsMin, E, I, D>(tensor, dim)
-    }
 }
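Every reduction kernel now takes a 2D dispatch, so each invocation first flattens its (x, y) coordinates back into a linear element index: the row stride is the grid width in invocations, num_workgroups.x * workgroup_size_x. A quick Rust mirror of that formula (names follow the WGSL builtins; the numbers are made up):

    // Mirror of the WGSL flattening: map a 2D invocation id to a linear
    // buffer index, using the total grid width as the row stride.
    fn linear_id(global_x: u32, global_y: u32, num_workgroups_x: u32, workgroup_size_x: u32) -> u32 {
        global_y * (num_workgroups_x * workgroup_size_x) + global_x
    }

    fn main() {
        let (workgroup_size_x, num_workgroups_x) = (32u32, 4); // grid width: 128
        // Invocation (x = 5, y = 2) owns buffer element 2 * 128 + 5 = 261.
        assert_eq!(linear_id(5, 2, num_workgroups_x, workgroup_size_x), 261);
    }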
diff --git a/burn-wgpu/src/template/reduction/args.wgsl b/burn-wgpu/src/template/reduction/args.wgsl
index e5c8d79b0..d162aacde 100644
--- a/burn-wgpu/src/template/reduction/args.wgsl
+++ b/burn-wgpu/src/template/reduction/args.wgsl
@@ -10,11 +10,15 @@ var<storage, read_write> output: array<{{ int }}>;
 @binding(2)
 var<storage, read> info: array<u32>;

+const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
+
 @compute
-@workgroup_size({{ workgroup_size_x }}, 1, 1)
+@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
 fn main(
     @builtin(global_invocation_id) global_id: vec3<u32>,
+    @builtin(num_workgroups) num_workgroups: vec3<u32>,
 ) {
+    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
     let dim: u32 = info[0];
     let dim_reduce = info[4u * dim + 1u];
     var index_offset: u32 = 0u;
@@ -26,7 +30,7 @@ fn main(
         let stride_output = info[i + dim];
         let shape_output = info[i + 3u * dim];

-        let num_block = global_id.x / stride_output % shape_output;
+        let num_block = id / stride_output % shape_output;

         if i - 1u != dim_reduce {
             index_offset += num_block * stride_input;
@@ -52,5 +56,5 @@ fn main(
         }
     }

-    output[global_id.x] = index;
+    output[id] = index;
 }
diff --git a/burn-wgpu/src/template/reduction/recursive_sum.wgsl b/burn-wgpu/src/template/reduction/recursive_sum.wgsl
index 7293adabd..9495af74a 100644
--- a/burn-wgpu/src/template/reduction/recursive_sum.wgsl
+++ b/burn-wgpu/src/template/reduction/recursive_sum.wgsl
@@ -1,4 +1,5 @@
-const BLOCK_SIZE = {{ workgroup_size_x }}u;
+const WORKGROUP_SIZE = {{ workgroup_size }}u;
+const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;

 @group(0)
 @binding(0)
@@ -8,26 +9,30 @@ var<storage, read> input: array<{{ elem }}>;
 @binding(1)
 var<storage, read_write> output: array<{{ elem }}>;

-var<workgroup> data: array<{{ elem }}, BLOCK_SIZE>;
+var<workgroup> data: array<{{ elem }}, WORKGROUP_SIZE>;

 @compute
-@workgroup_size({{ workgroup_size_x }}, 1, 1)
+@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
 fn main(
     @builtin(global_invocation_id) global_id: vec3<u32>,
     @builtin(local_invocation_id) local_id: vec3<u32>,
     @builtin(workgroup_id) workgroup_id: vec3<u32>,
     @builtin(num_workgroups) num_workgroups: vec3<u32>,
 ) {
-    data[local_id.x] = input[global_id.x];
+    let id_global = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
+    let id_local = local_id.y * WORKGROUP_SIZE_X + local_id.x;
+
+    data[id_local] = input[id_global];

     workgroupBarrier();

-    if local_id.x == 0u {
+    if id_local == 0u {
         var sum = {{ elem }}(0);
-        for (var i: u32 = 0u; i < BLOCK_SIZE; i++) {
+        for (var i: u32 = 0u; i < WORKGROUP_SIZE; i++) {
             sum += data[i];
         }
-        output[workgroup_id.x] = sum;
+        let id_output = workgroup_id.y * num_workgroups.x + workgroup_id.x;
+        output[id_output] = sum;
     }
 }
diff --git a/burn-wgpu/src/template/reduction/reduce_dim.wgsl b/burn-wgpu/src/template/reduction/reduce_dim.wgsl
index 14d2cc9c1..7c47dde0a 100644
--- a/burn-wgpu/src/template/reduction/reduce_dim.wgsl
+++ b/burn-wgpu/src/template/reduction/reduce_dim.wgsl
@@ -10,11 +10,15 @@ var<storage, read_write> output: array<{{ elem }}>;
 @binding(2)
 var<storage, read> info: array<u32>;

+const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
+
 @compute
-@workgroup_size({{ workgroup_size_x }}, 1, 1)
+@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
 fn main(
     @builtin(global_invocation_id) global_id: vec3<u32>,
+    @builtin(num_workgroups) num_workgroups: vec3<u32>,
 ) {
+    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
     let dim: u32 = info[0];
     let dim_reduce = info[4u * dim + 1u];
     var index_offset: u32 = 0u;
@@ -26,7 +30,7 @@ fn main(
         let stride_output = info[i + dim];
         let shape_output = info[i + 3u * dim];

-        let num_block = global_id.x / stride_output % shape_output;
+        let num_block = id / stride_output % shape_output;

         if i - 1u != dim_reduce {
             index_offset += num_block * stride_input;