From 06157d3cde5a1d491e007df3a4cfc6129e4d6e6e Mon Sep 17 00:00:00 2001
From: Justin Moore
Date: Sat, 2 Sep 2023 10:34:58 -0500
Subject: [PATCH] Perf/ndarray: Optimize `conv2d` operation (#747)

---
 burn-ndarray/src/ops/adaptive_avgpool.rs |   6 +-
 burn-ndarray/src/ops/avgpool.rs          |   6 +-
 burn-ndarray/src/ops/base.rs             |  24 ++--
 burn-ndarray/src/ops/conv.rs             | 154 ++++++++++++++++++-----
 burn-ndarray/src/ops/matmul.rs           |   4 +-
 burn-ndarray/src/ops/maxpool.rs          |   8 +-
 burn-ndarray/src/parallel.rs             |  18 ++-
 7 files changed, 164 insertions(+), 56 deletions(-)

diff --git a/burn-ndarray/src/ops/adaptive_avgpool.rs b/burn-ndarray/src/ops/adaptive_avgpool.rs
index 1d63475a6..1e91aa227 100644
--- a/burn-ndarray/src/ops/adaptive_avgpool.rs
+++ b/burn-ndarray/src/ops/adaptive_avgpool.rs
@@ -1,5 +1,5 @@
 use crate::{
-    element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
+    element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
     tensor::NdArrayTensor,
 };
 use burn_tensor::ElementConversion;
@@ -19,7 +19,7 @@ pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
 
     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
 
@@ -61,7 +61,7 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);
 
     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
 
diff --git a/burn-ndarray/src/ops/avgpool.rs b/burn-ndarray/src/ops/avgpool.rs
index bf79844b3..680c4e117 100644
--- a/burn-ndarray/src/ops/avgpool.rs
+++ b/burn-ndarray/src/ops/avgpool.rs
@@ -1,5 +1,5 @@
 use crate::{
-    element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
+    element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
     tensor::NdArrayTensor,
 };
 
@@ -27,7 +27,7 @@ pub(crate) fn avg_pool2d<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
 
     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
 
@@ -92,7 +92,7 @@ pub(crate) fn avg_pool2d_backward<E: FloatNdArrayElement>(
     let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad);
 
     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
 
diff --git a/burn-ndarray/src/ops/base.rs b/burn-ndarray/src/ops/base.rs
index bcbfdaa04..749ebaf7c 100644
--- a/burn-ndarray/src/ops/base.rs
+++ b/burn-ndarray/src/ops/base.rs
@@ -392,30 +392,36 @@ where
         arg(tensor, dim, CmpType::Min)
     }
 
-    pub fn clamp_min<const D: usize>(tensor: NdArrayTensor<E, D>, min: E) -> NdArrayTensor<E, D> {
-        let array = tensor.array.mapv(|x| match x < min {
+    pub fn clamp_min<const D: usize>(
+        mut tensor: NdArrayTensor<E, D>,
+        min: E,
+    ) -> NdArrayTensor<E, D> {
+        tensor.array.mapv_inplace(|x| match x < min {
             true => min,
             false => x,
         });
 
-        NdArrayTensor::new(array.into_shared())
+        tensor
     }
 
-    pub fn clamp_max<const D: usize>(tensor: NdArrayTensor<E, D>, max: E) -> NdArrayTensor<E, D> {
-        let array = tensor.array.mapv(|x| match x > max {
+    pub fn clamp_max<const D: usize>(
+        mut tensor: NdArrayTensor<E, D>,
+        max: E,
+    ) -> NdArrayTensor<E, D> {
+        tensor.array.mapv_inplace(|x| match x > max {
             true => max,
             false => x,
         });
 
-        NdArrayTensor::new(array.into_shared())
+        tensor
    }
 
     pub fn clamp<const D: usize>(
-        tensor: NdArrayTensor<E, D>,
+        mut tensor: NdArrayTensor<E, D>,
         min: E,
         max: E,
     ) -> NdArrayTensor<E, D> {
-        let array = tensor.array.mapv(|x| match x < min {
+        tensor.array.mapv_inplace(|x| match x < min {
             true => min,
             false => match x > max {
                 true => max,
@@ -423,7 +429,7 @@
             },
         });
 
-        NdArrayTensor::new(array.into_shared())
+        tensor
     }
 }
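The `base.rs` hunks above switch the clamp family from `mapv`, which allocates a whole new array, to `mapv_inplace`, which rewrites the tensor's existing buffer and returns it by move. A minimal standalone sketch of the two `ndarray` calls involved (the shape and values here are illustrative, not taken from the patch):

    use ndarray::Array2;

    fn main() {
        let mut a = Array2::<f32>::from_elem((2, 3), 5.0);

        // `mapv` allocates and returns a brand-new array.
        let clamped = a.mapv(|x| if x > 4.0 { 4.0 } else { x });

        // `mapv_inplace` rewrites the existing buffer, which is what the
        // patched `clamp_min`/`clamp_max`/`clamp` now do before returning
        // the tensor itself.
        a.mapv_inplace(|x| if x > 4.0 { 4.0 } else { x });

        assert_eq!(a, clamped);
    }

For large tensors this replaces an allocation plus copy with a single in-place pass, which is why the signatures change to take `mut tensor` and return `tensor` directly.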
diff --git a/burn-ndarray/src/ops/conv.rs b/burn-ndarray/src/ops/conv.rs
index 8f43e026e..1d4fcc259 100644
--- a/burn-ndarray/src/ops/conv.rs
+++ b/burn-ndarray/src/ops/conv.rs
@@ -2,20 +2,58 @@ use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions},
     ElementConversion,
 };
-use ndarray::{Array4, Dim};
+use ndarray::{s, Array3, Array4, ArrayView2, ArrayViewMut2, Axis, Dim};
 
 use crate::{
-    element::FloatNdArrayElement, iter_par, ops::padding::apply_padding_4d, run_par,
-    sharing::UnsafeSharedRef, tensor::NdArrayTensor,
+    element::FloatNdArrayElement, iter_par, iter_range_par, ops::padding::apply_padding_4d,
+    run_par, sharing::UnsafeSharedRef, tensor::NdArrayTensor,
 };
 
+#[inline(always)]
+fn conv2d_mad_inner<E: FloatNdArrayElement>(
+    mut output: ArrayViewMut2<E>,
+    x: ArrayView2<E>,
+    k: E,
+    k_xy: (usize, usize),
+    out_xy: (usize, usize),
+    stride: (usize, usize),
+    dilation: (usize, usize),
+) {
+    let (kh, kw) = k_xy;
+    let (out_width, out_height) = out_xy;
+    let (stride_width, stride_height) = stride;
+    let (dilation_width, dilation_height) = dilation;
+
+    for oh in 0..out_height {
+        // Construct a sub-slice view of the input row.
+        // This is done upfront so that rustc does not have to emit bounds checks
+        // in the hot loop below.
+        let ir = x
+            .row(oh * stride_height + kh * dilation_height)
+            .to_slice()
+            .unwrap();
+
+        // Ditto. Construct a sub-slice view of the output row, and explicitly specify
+        // the bounds upfront as 0..out_width so that rustc can make the assumption
+        // that all accesses are in-bounds in the below loop.
+        let mut or = output.row_mut(oh);
+        let or = &mut or.as_slice_mut().unwrap()[0..out_width];
+
+        #[allow(clippy::needless_range_loop)]
+        for ow in 0..out_width {
+            let iw = (ow * stride_width) + (kw * dilation_width);
+            or[ow] += ir[iw] * k;
+        }
+    }
+}
+
 pub(crate) fn conv2d<E: FloatNdArrayElement>(
     x: NdArrayTensor<E, 4>,
     weight: NdArrayTensor<E, 4>,
     bias: Option<NdArrayTensor<E, 1>>,
     options: ConvOptions<2>,
 ) -> NdArrayTensor<E, 4> {
-    let [dilatation_height, dilatation_width] = options.dilation;
+    let [dilation_height, dilation_width] = options.dilation;
     let [padding_height, padding_width] = options.padding;
     let [stride_height, stride_width] = options.stride;
     let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
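The new `conv2d_mad_inner` above is the heart of the optimization: each input and output row is pinned to a plain slice before the hot loop, so the loop body indexes slices whose lengths rustc already knows. A hypothetical reduction of that pattern to its essentials (names here are illustrative, not from the patch):

    fn mad_row(out: &mut [f32], input: &[f32], k: f32) {
        // Re-slicing both buffers to one shared length `n` up front proves to
        // rustc that every `out[i]` and `input[i]` below is in-bounds, so the
        // loop needs no per-element bounds checks and can be auto-vectorized.
        let n = out.len().min(input.len());
        let (out, input) = (&mut out[..n], &input[..n]);

        for i in 0..n {
            out[i] += input[i] * k;
        }
    }

In the patch itself the same effect comes from `to_slice().unwrap()` on the input row and `&mut or.as_slice_mut().unwrap()[0..out_width]` on the output row.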
@@ -25,59 +63,107 @@
         kernel_height,
         stride_height,
         padding_height,
-        dilatation_height,
+        dilation_height,
         in_height,
     );
     let out_width = calculate_conv_output_size(
         kernel_width,
         stride_width,
         padding_width,
-        dilatation_width,
+        dilation_width,
         in_width,
     );
 
     let x = apply_padding_4d(x, options.padding, 0i32.elem()).array;
 
-    let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));
+    // Convert inputs from dynamic indexes to static to improve perf.
+    let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
+    let weights = weight.array.into_dimensionality::<ndarray::Ix4>().unwrap();
 
-    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
+    let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width]));
 
     run_par!(|| {
-        iter_par!(0, batch_size * out_channels).for_each(|k| unsafe {
-            let b = k / out_channels;
-            let oc = k % out_channels;
-            let g = k % options.groups;
+        iter_par!(output.axis_iter_mut(Axis(0)))
+            .enumerate()
+            .for_each(
+                #[inline(never)]
+                |(k, mut output)| {
+                    let b = k / out_channels;
+                    let oc = k % out_channels;
+                    let g = k % options.groups;
 
-            let output = unsafe_shared_out.get();
+                    for ic in (in_channels * g)..(in_channels * (g + 1)) {
+                        let weight_ic = ic - (g * in_channels);
 
-            for ic in (in_channels * g)..(in_channels * (g + 1)) {
-                for kh in 0..kernel_height {
-                    for kw in 0..kernel_width {
-                        for oh in 0..out_height {
-                            for ow in 0..out_width {
-                                let ih = oh * stride_height + kh * dilatation_height;
-                                let iw = ow * stride_width + kw * dilatation_width;
+                        let x = x.slice(s![b, ic, .., ..]);
+                        let k = weights.slice(s![oc, weight_ic, .., ..]);
 
-                                let weight_ic = ic - (g * in_channels);
-                                output[[b, oc, oh, ow]] +=
-                                    x[[b, ic, ih, iw]] * weight.array[[oc, weight_ic, kh, kw]];
+                        for kh in 0..kernel_height {
+                            for kw in 0..kernel_width {
+                                let k = k[[kh, kw]];
+
+                                // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization
+                                // in the case that the stride/dilation is 1.
+                                #[allow(clippy::if_same_then_else)]
+                                if (1, 1, 1, 1)
+                                    == (
+                                        stride_width,
+                                        stride_height,
+                                        dilation_width,
+                                        dilation_height,
+                                    )
+                                {
+                                    conv2d_mad_inner(
+                                        output.view_mut(),
+                                        x.view(),
+                                        k,
+                                        (kh, kw),
+                                        (out_width, out_height),
+                                        (stride_width, stride_height),
+                                        (dilation_width, dilation_height),
+                                    );
+                                } else {
+                                    conv2d_mad_inner(
+                                        output.view_mut(),
+                                        x.view(),
+                                        k,
+                                        (kh, kw),
+                                        (out_width, out_height),
+                                        (stride_width, stride_height),
+                                        (dilation_width, dilation_height),
+                                    );
+                                }
                             }
                         }
                     }
-                }
-            }
 
-            if let Some(bias) = &bias {
-                for oh in 0..out_height {
-                    for ow in 0..out_width {
-                        output[[b, oc, oh, ow]] += bias.array[oc];
+                    if let Some(bias) = &bias {
+                        let bias = bias.array[oc];
+
+                        for oh in 0..out_height {
+                            // Get a mutable slice reference to the row we're looping over.
+                            // We explicitly define the bounds to 0..out_width so that rustc can make
+                            // the assumption that all accesses are in-bounds.
+                            let mut or = output.row_mut(oh);
+                            let or = &mut or.as_slice_mut().unwrap()[0..out_width];
+
+                            #[allow(clippy::needless_range_loop)]
+                            for ow in 0..out_width {
+                                or[ow] += bias;
+                            }
+                        }
                     }
-                }
-            }
-        });
+                },
+            );
     });
 
-    NdArrayTensor::new(output.into_dyn().into_shared())
+    let output = output
+        .into_shape([batch_size, out_channels, out_height, out_width])
+        .unwrap()
+        .into_dyn()
+        .into_shared();
+
+    NdArrayTensor::new(output)
 }
 
 pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
@@ -114,7 +200,7 @@
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
 
     run_par!(|| {
-        iter_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
             let b = k / (out_channels * options.groups);
             let oc = k % out_channels;
             let g = k % options.groups;
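Two details of the `conv2d` rewrite above are easy to miss. First, the output is allocated as an `Array3` whose leading axis fuses batch and output channel, so `axis_iter_mut` hands each parallel task exclusive ownership of one `(out_height, out_width)` plane and the `UnsafeSharedRef` disappears. Second, the duplicated `conv2d_mad_inner` call gives the optimizer a branch in which stride and dilation are all known to be 1, so that copy of the inlined body can be vectorized. A rough sketch of the call-duplication trick (helper names are hypothetical):

    #[inline(always)]
    fn accumulate(out: &mut [f32], input: &[f32], step: usize) {
        // Caller guarantees input.len() > (out.len() - 1) * step.
        for (i, o) in out.iter_mut().enumerate() {
            *o += input[i * step];
        }
    }

    fn accumulate_dispatch(out: &mut [f32], input: &[f32], step: usize) {
        #[allow(clippy::if_same_then_else)]
        if step == 1 {
            // After inlining, `step == 1` is a known fact in this arm, so
            // the loop becomes a unit-stride multiply-add that LLVM can
            // auto-vectorize.
            accumulate(out, input, step);
        } else {
            // Generic strided fallback: same source, optimized separately.
            accumulate(out, input, step);
        }
    }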
diff --git a/burn-ndarray/src/ops/matmul.rs b/burn-ndarray/src/ops/matmul.rs
index 9a1c93f9c..0dd2d8d02 100644
--- a/burn-ndarray/src/ops/matmul.rs
+++ b/burn-ndarray/src/ops/matmul.rs
@@ -1,5 +1,5 @@
 use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
-use crate::{iter_par, run_par, UnsafeSharedRef};
+use crate::{iter_range_par, run_par, UnsafeSharedRef};
 use burn_tensor::ElementConversion;
 use burn_tensor::{ops::TensorOps, Shape};
 use ndarray::s;
@@ -58,7 +58,7 @@ fn general_matmul<E: FloatNdArrayElement, const D: usize>(
     let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
     let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();
 
-    iter_par!(0, batch_size).for_each(|b| {
+    iter_range_par!(0, batch_size).for_each(|b| {
         let lhs_slice = match batch_size_lhs == 1 {
             true => lhs_array.slice(s!(0, .., ..)),
             false => lhs_array.slice(s!(b, .., ..)),
diff --git a/burn-ndarray/src/ops/maxpool.rs b/burn-ndarray/src/ops/maxpool.rs
index 32973b359..948c94293 100644
--- a/burn-ndarray/src/ops/maxpool.rs
+++ b/burn-ndarray/src/ops/maxpool.rs
@@ -1,5 +1,5 @@
 use crate::{
-    element::FloatNdArrayElement, iter_par, ops::padding::apply_padding_4d, run_par,
+    element::FloatNdArrayElement, iter_range_par, ops::padding::apply_padding_4d, run_par,
     sharing::UnsafeSharedRef, tensor::NdArrayTensor,
 };
 
@@ -33,7 +33,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
 
     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
 
@@ -96,7 +96,7 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
     let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
 
     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
 
@@ -159,7 +159,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
 
     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
 
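The `parallel.rs` diff below is what the renames in the earlier files depend on: the old range-based `iter_par!` becomes `iter_range_par!`, and `iter_par!` is reworked to accept any parallelizable iterator, which is what lets `conv2d` iterate `axis_iter_mut` in parallel. Roughly what the two call-site forms expand to on the `std` path (a sketch assuming `ndarray` is compiled with its `rayon` feature; the no-`std` path simply keeps the serial iterator):

    use ndarray::parallel::prelude::*;
    use ndarray::{Array3, Axis};

    fn main() {
        // iter_range_par!(0, 8) expands to roughly this:
        (0..8).into_par_iter().for_each(|i| {
            let _ = i; // per-index work, as in the pooling kernels
        });

        // iter_par!(output.axis_iter_mut(Axis(0))) expands to roughly this:
        let mut output = Array3::<f32>::zeros((4, 2, 2));
        output
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
            .for_each(|(k, mut plane)| {
                plane.fill(k as f32); // each task owns one output plane
            });
    }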
iter_par { + ( + $iter:expr + ) => {{ + #[cfg(feature = "std")] + let output = $iter.into_par_iter(); + + #[cfg(not(feature = "std"))] + let output = $iter; + + output + }}; +} + +/// Macro for iterating over a range in parallel. +#[macro_export(local_inner_macros)] +macro_rules! iter_range_par { ( $start:expr, $end:expr ) => {{