Perf/ndarray: Optimize `conv2d` operation (#747)

Justin Moore authored on 2023-09-02 10:34:58 -05:00 (committed by GitHub)
parent a47d23c3dd
commit 06157d3cde
7 changed files with 164 additions and 56 deletions

File 1 of 7: adaptive average pooling ops

@@ -1,5 +1,5 @@
 use crate::{
-    element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
+    element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
     tensor::NdArrayTensor,
 };
 use burn_tensor::ElementConversion;
@@ -19,7 +19,7 @@ pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
@@ -61,7 +61,7 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);

     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;

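Throughout the pooling kernels the only change is a rename: the range-based `iter_par!` becomes `iter_range_par!`, freeing the `iter_par!` name for true iterators (see the macros file at the end of this commit). A minimal sketch of the pattern these kernels rely on, written directly against rayon rather than the crate's macros:

    use rayon::prelude::*;

    fn main() {
        let (batch_size, channels) = (2, 3);

        // Equivalent of `iter_range_par!(0, batch_size * channels)` with the
        // `std` feature enabled: parallelize over a flattened (batch, channel)
        // index and decompose it inside the closure, as the kernels above do.
        (0..batch_size * channels).into_par_iter().for_each(|k| {
            let b = k / channels;
            let c = k % channels;
            println!("batch {b}, channel {c}");
        });
    }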
File 2 of 7: average pooling ops

@@ -1,5 +1,5 @@
 use crate::{
-    element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
+    element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
     tensor::NdArrayTensor,
 };
@@ -27,7 +27,7 @@ pub(crate) fn avg_pool2d<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
@@ -92,7 +92,7 @@ pub(crate) fn avg_pool2d_backward<E: FloatNdArrayElement>(
     let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad);

     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;

File 3 of 7: tensor clamp ops

@@ -392,30 +392,36 @@ where
         arg(tensor, dim, CmpType::Min)
     }

-    pub fn clamp_min<const D: usize>(tensor: NdArrayTensor<E, D>, min: E) -> NdArrayTensor<E, D> {
-        let array = tensor.array.mapv(|x| match x < min {
+    pub fn clamp_min<const D: usize>(
+        mut tensor: NdArrayTensor<E, D>,
+        min: E,
+    ) -> NdArrayTensor<E, D> {
+        tensor.array.mapv_inplace(|x| match x < min {
             true => min,
             false => x,
         });

-        NdArrayTensor::new(array.into_shared())
+        tensor
     }

-    pub fn clamp_max<const D: usize>(tensor: NdArrayTensor<E, D>, max: E) -> NdArrayTensor<E, D> {
-        let array = tensor.array.mapv(|x| match x > max {
+    pub fn clamp_max<const D: usize>(
+        mut tensor: NdArrayTensor<E, D>,
+        max: E,
+    ) -> NdArrayTensor<E, D> {
+        tensor.array.mapv_inplace(|x| match x > max {
             true => max,
             false => x,
         });

-        NdArrayTensor::new(array.into_shared())
+        tensor
     }

     pub fn clamp<const D: usize>(
-        tensor: NdArrayTensor<E, D>,
+        mut tensor: NdArrayTensor<E, D>,
         min: E,
         max: E,
     ) -> NdArrayTensor<E, D> {
-        let array = tensor.array.mapv(|x| match x < min {
+        tensor.array.mapv_inplace(|x| match x < min {
             true => min,
             false => match x > max {
                 true => max,
@@ -423,7 +429,7 @@ where
             },
         });

-        NdArrayTensor::new(array.into_shared())
+        tensor
     }
 }

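The clamp rewrite is an allocation fix rather than an algorithmic one: `mapv` returns a freshly allocated array, while `mapv_inplace` mutates the existing buffer, so taking the tensor by value and handing it back skips one full-tensor allocation per clamp. A small self-contained sketch of the difference using plain ndarray types (not code from this commit):

    use ndarray::Array2;

    fn main() {
        let mut a = Array2::<f32>::from_shape_fn((2, 3), |(i, j)| (i * 3 + j) as f32);

        // `mapv` allocates and returns a new array; `a` is left untouched.
        let clamped = a.mapv(|x| if x > 3.0 { 3.0 } else { x });

        // `mapv_inplace` reuses `a`'s buffer; no new allocation happens.
        a.mapv_inplace(|x| if x > 3.0 { 3.0 } else { x });

        assert_eq!(a, clamped);
    }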
File 4 of 7: convolution ops

@@ -2,20 +2,58 @@ use burn_tensor::{
     ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions},
     ElementConversion,
 };
-use ndarray::{Array4, Dim};
+use ndarray::{s, Array3, Array4, ArrayView2, ArrayViewMut2, Axis, Dim};

 use crate::{
-    element::FloatNdArrayElement, iter_par, ops::padding::apply_padding_4d, run_par,
-    sharing::UnsafeSharedRef, tensor::NdArrayTensor,
+    element::FloatNdArrayElement, iter_par, iter_range_par, ops::padding::apply_padding_4d,
+    run_par, sharing::UnsafeSharedRef, tensor::NdArrayTensor,
 };

+#[inline(always)]
+fn conv2d_mad_inner<E: FloatNdArrayElement>(
+    mut output: ArrayViewMut2<E>,
+    x: ArrayView2<E>,
+    k: E,
+    k_xy: (usize, usize),
+    out_xy: (usize, usize),
+    stride: (usize, usize),
+    dilation: (usize, usize),
+) {
+    let (kh, kw) = k_xy;
+    let (out_width, out_height) = out_xy;
+    let (stride_width, stride_height) = stride;
+    let (dilation_width, dilation_height) = dilation;
+
+    for oh in 0..out_height {
+        // Construct a sub-slice view of the input row.
+        // This is done upfront so that rustc does not have to emit bounds checks
+        // in the hot loop below.
+        let ir = x
+            .row(oh * stride_height + kh * dilation_height)
+            .to_slice()
+            .unwrap();
+
+        // Ditto. Construct a sub-slice view of the output row, and explicitly specify
+        // the bounds upfront as 0..out_width so that rustc can make the assumption
+        // that all accesses are in-bounds in the below loop.
+        let mut or = output.row_mut(oh);
+        let or = &mut or.as_slice_mut().unwrap()[0..out_width];
+
+        #[allow(clippy::needless_range_loop)]
+        for ow in 0..out_width {
+            let iw = (ow * stride_width) + (kw * dilation_width);
+            or[ow] += ir[iw] * k;
+        }
+    }
+}
+
 pub(crate) fn conv2d<E: FloatNdArrayElement>(
     x: NdArrayTensor<E, 4>,
     weight: NdArrayTensor<E, 4>,
     bias: Option<NdArrayTensor<E, 1>>,
     options: ConvOptions<2>,
 ) -> NdArrayTensor<E, 4> {
-    let [dilatation_height, dilatation_width] = options.dilation;
+    let [dilation_height, dilation_width] = options.dilation;
     let [padding_height, padding_width] = options.padding;
     let [stride_height, stride_width] = options.stride;
     let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
@@ -25,59 +63,107 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
         kernel_height,
         stride_height,
         padding_height,
-        dilatation_height,
+        dilation_height,
         in_height,
     );
     let out_width = calculate_conv_output_size(
         kernel_width,
         stride_width,
         padding_width,
-        dilatation_width,
+        dilation_width,
         in_width,
     );

     let x = apply_padding_4d(x, options.padding, 0i32.elem()).array;

-    let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));
+    // Convert inputs from dynamic indexes to static to improve perf.
+    let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
+    let weights = weight.array.into_dimensionality::<ndarray::Ix4>().unwrap();

-    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
+    let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width]));

     run_par!(|| {
-        iter_par!(0, batch_size * out_channels).for_each(|k| unsafe {
-            let b = k / out_channels;
-            let oc = k % out_channels;
-            let g = k % options.groups;
-            let output = unsafe_shared_out.get();
-
-            for ic in (in_channels * g)..(in_channels * (g + 1)) {
-                for kh in 0..kernel_height {
-                    for kw in 0..kernel_width {
-                        for oh in 0..out_height {
-                            for ow in 0..out_width {
-                                let ih = oh * stride_height + kh * dilatation_height;
-                                let iw = ow * stride_width + kw * dilatation_width;
-
-                                let weight_ic = ic - (g * in_channels);
-                                output[[b, oc, oh, ow]] +=
-                                    x[[b, ic, ih, iw]] * weight.array[[oc, weight_ic, kh, kw]];
-                            }
-                        }
-                    }
-                }
-            }
-
-            if let Some(bias) = &bias {
-                for oh in 0..out_height {
-                    for ow in 0..out_width {
-                        output[[b, oc, oh, ow]] += bias.array[oc];
-                    }
-                }
-            }
-        });
+        iter_par!(output.axis_iter_mut(Axis(0)))
+            .enumerate()
+            .for_each(
+                #[inline(never)]
+                |(k, mut output)| {
+                    let b = k / out_channels;
+                    let oc = k % out_channels;
+                    let g = k % options.groups;
+
+                    for ic in (in_channels * g)..(in_channels * (g + 1)) {
+                        let weight_ic = ic - (g * in_channels);
+
+                        let x = x.slice(s![b, ic, .., ..]);
+                        let k = weights.slice(s![oc, weight_ic, .., ..]);
+
+                        for kh in 0..kernel_height {
+                            for kw in 0..kernel_width {
+                                let k = k[[kh, kw]];
+
+                                // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization
+                                // in the case that the stride/dilation is 1.
+                                #[allow(clippy::if_same_then_else)]
+                                if (1, 1, 1, 1)
+                                    == (
+                                        stride_width,
+                                        stride_height,
+                                        dilation_width,
+                                        dilation_height,
+                                    )
+                                {
+                                    conv2d_mad_inner(
+                                        output.view_mut(),
+                                        x.view(),
+                                        k,
+                                        (kh, kw),
+                                        (out_width, out_height),
+                                        (stride_width, stride_height),
+                                        (dilation_width, dilation_height),
+                                    );
+                                } else {
+                                    conv2d_mad_inner(
+                                        output.view_mut(),
+                                        x.view(),
+                                        k,
+                                        (kh, kw),
+                                        (out_width, out_height),
+                                        (stride_width, stride_height),
+                                        (dilation_width, dilation_height),
+                                    );
+                                }
+                            }
+                        }
+                    }
+
+                    if let Some(bias) = &bias {
+                        let bias = bias.array[oc];
+
+                        for oh in 0..out_height {
+                            // Get a mutable slice reference to the row we're looping over.
+                            // We explicitly define the bounds to 0..out_width so that rustc can make
+                            // the assumption that all accesses are in-bounds.
+                            let mut or = output.row_mut(oh);
+                            let or = &mut or.as_slice_mut().unwrap()[0..out_width];
+
+                            #[allow(clippy::needless_range_loop)]
+                            for ow in 0..out_width {
+                                or[ow] += bias;
+                            }
+                        }
+                    }
+                },
+            );
     });

-    NdArrayTensor::new(output.into_dyn().into_shared())
+    let output = output
+        .into_shape([batch_size, out_channels, out_height, out_width])
+        .unwrap()
+        .into_dyn()
+        .into_shared();
+
+    NdArrayTensor::new(output)
 }

 pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
@@ -114,7 +200,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

     run_par!(|| {
-        iter_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
             let b = k / (out_channels * options.groups);
             let oc = k % out_channels;
             let g = k % options.groups;

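Two things in the conv2d rewrite deserve a note. First, `into_dimensionality::<Ix4>()` trades the dynamic-rank representation for a static one, so indexing compiles down to fixed offset arithmetic instead of a per-access loop over axes. Second, the duplicated `conv2d_mad_inner` call is not dead code: inside the `(1, 1, 1, 1)` branch the optimizer knows every stride and dilation equals 1, so its inlined copy of the helper reduces to unit-stride loads it can auto-vectorize, while the `else` copy stays generic. A stripped-down sketch of that branch-specialization trick, with hypothetical `mad`/`accumulate` helpers that are not part of the commit:

    #[inline(always)]
    fn mad(out: &mut [f32], input: &[f32], k: f32, stride: usize) {
        // When `stride` is known to be 1 at the call site, `ow * stride`
        // folds away and this loop vectorizes with contiguous loads.
        for (ow, o) in out.iter_mut().enumerate() {
            *o += input[ow * stride] * k;
        }
    }

    fn accumulate(out: &mut [f32], input: &[f32], k: f32, stride: usize) {
        // Both branches are textually identical, but each receives its own
        // inlined copy of `mad`, and the copy under `stride == 1` is
        // compiled with that fact in scope.
        #[allow(clippy::if_same_then_else)]
        if stride == 1 {
            mad(out, input, k, stride);
        } else {
            mad(out, input, k, stride);
        }
    }

    fn main() {
        let input = vec![1.0_f32; 16];
        let mut out = vec![0.0_f32; 8];
        accumulate(&mut out, &input, 0.5, 2);
        assert!(out.iter().all(|&v| (v - 0.5).abs() < 1e-6));
    }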
File 5 of 7: matmul ops

@@ -1,5 +1,5 @@
 use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
-use crate::{iter_par, run_par, UnsafeSharedRef};
+use crate::{iter_range_par, run_par, UnsafeSharedRef};
 use burn_tensor::ElementConversion;
 use burn_tensor::{ops::TensorOps, Shape};
 use ndarray::s;
@@ -58,7 +58,7 @@ fn general_matmul<E: FloatNdArrayElement>(
     let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
     let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();

-    iter_par!(0, batch_size).for_each(|b| {
+    iter_range_par!(0, batch_size).for_each(|b| {
         let lhs_slice = match batch_size_lhs == 1 {
             true => lhs_array.slice(s!(0, .., ..)),
             false => lhs_array.slice(s!(b, .., ..)),

File 6 of 7: max pooling ops

@@ -1,5 +1,5 @@
 use crate::{
-    element::FloatNdArrayElement, iter_par, ops::padding::apply_padding_4d, run_par,
+    element::FloatNdArrayElement, iter_range_par, ops::padding::apply_padding_4d, run_par,
     sharing::UnsafeSharedRef, tensor::NdArrayTensor,
 };
@@ -33,7 +33,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
@@ -96,7 +96,7 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
     let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);

     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;
@@ -159,7 +159,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
     let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

     run_par!(|| {
-        iter_par!(0, batch_size * channels).for_each(|k| unsafe {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
             let b = k / channels;
             let c = k % channels;

File 7 of 7: parallel iteration macros

@@ -18,9 +18,25 @@ macro_rules! run_par {
     }};
 }

-/// Macro for iterating over a range in parallel.
+/// Macro for iterating in parallel.
 #[macro_export(local_inner_macros)]
 macro_rules! iter_par {
+    (
+        $iter:expr
+    ) => {{
+        #[cfg(feature = "std")]
+        let output = $iter.into_par_iter();
+        #[cfg(not(feature = "std"))]
+        let output = $iter;
+        output
+    }};
+}
+
+/// Macro for iterating over a range in parallel.
+#[macro_export(local_inner_macros)]
+macro_rules! iter_range_par {
     (
         $start:expr, $end:expr
     ) => {{
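The new `iter_par!` arm is what enables the conv2d rewrite above: instead of a flat index range plus `UnsafeSharedRef`, the kernel hands rayon disjoint mutable views of the output, one per (batch, out_channel) plane, so the `unsafe` sharing disappears. A sketch of roughly what the macro expands to with the `std` feature on, assuming ndarray is built with its `rayon` feature:

    use ndarray::parallel::prelude::*;
    use ndarray::{Array3, Axis};

    fn main() {
        let mut output = Array3::<f32>::zeros((4, 8, 8));

        // `iter_par!(output.axis_iter_mut(Axis(0)))` boils down to this:
        // each outer plane is a disjoint `ArrayViewMut2`, so the closure can
        // mutate it on the rayon thread pool without any unsafe sharing.
        output
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
            .for_each(|(k, mut plane)| plane.fill(k as f32));

        assert_eq!(output[[3, 0, 0]], 3.0);
    }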