Perf/ndarray: Optimize `conv2d` operation (#747)

This commit is contained in:
Justin Moore 2023-09-02 10:34:58 -05:00 committed by GitHub
parent a47d23c3dd
commit 06157d3cde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 164 additions and 56 deletions

View File

@ -1,5 +1,5 @@
use crate::{
element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
tensor::NdArrayTensor,
};
use burn_tensor::ElementConversion;
@ -19,7 +19,7 @@ pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
run_par!(|| {
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;
@ -61,7 +61,7 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);
run_par!(|| {
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;

View File

@ -1,5 +1,5 @@
use crate::{
element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
tensor::NdArrayTensor,
};
@ -27,7 +27,7 @@ pub(crate) fn avg_pool2d<E: FloatNdArrayElement>(
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
run_par!(|| {
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;
@ -92,7 +92,7 @@ pub(crate) fn avg_pool2d_backward<E: FloatNdArrayElement>(
let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad);
run_par!(|| {
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;

View File

@ -392,30 +392,36 @@ where
arg(tensor, dim, CmpType::Min)
}
pub fn clamp_min<const D: usize>(tensor: NdArrayTensor<E, D>, min: E) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv(|x| match x < min {
pub fn clamp_min<const D: usize>(
mut tensor: NdArrayTensor<E, D>,
min: E,
) -> NdArrayTensor<E, D> {
tensor.array.mapv_inplace(|x| match x < min {
true => min,
false => x,
});
NdArrayTensor::new(array.into_shared())
tensor
}
pub fn clamp_max<const D: usize>(tensor: NdArrayTensor<E, D>, max: E) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv(|x| match x > max {
pub fn clamp_max<const D: usize>(
mut tensor: NdArrayTensor<E, D>,
max: E,
) -> NdArrayTensor<E, D> {
tensor.array.mapv_inplace(|x| match x > max {
true => max,
false => x,
});
NdArrayTensor::new(array.into_shared())
tensor
}
pub fn clamp<const D: usize>(
tensor: NdArrayTensor<E, D>,
mut tensor: NdArrayTensor<E, D>,
min: E,
max: E,
) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv(|x| match x < min {
tensor.array.mapv_inplace(|x| match x < min {
true => min,
false => match x > max {
true => max,
@ -423,7 +429,7 @@ where
},
});
NdArrayTensor::new(array.into_shared())
tensor
}
}

View File

@ -2,20 +2,58 @@ use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions},
ElementConversion,
};
use ndarray::{Array4, Dim};
use ndarray::{s, Array3, Array4, ArrayView2, ArrayViewMut2, Axis, Dim};
use crate::{
element::FloatNdArrayElement, iter_par, ops::padding::apply_padding_4d, run_par,
sharing::UnsafeSharedRef, tensor::NdArrayTensor,
element::FloatNdArrayElement, iter_par, iter_range_par, ops::padding::apply_padding_4d,
run_par, sharing::UnsafeSharedRef, tensor::NdArrayTensor,
};
#[inline(always)]
fn conv2d_mad_inner<E: FloatNdArrayElement>(
mut output: ArrayViewMut2<E>,
x: ArrayView2<E>,
k: E,
k_xy: (usize, usize),
out_xy: (usize, usize),
stride: (usize, usize),
dilation: (usize, usize),
) {
let (kh, kw) = k_xy;
let (out_width, out_height) = out_xy;
let (stride_width, stride_height) = stride;
let (dilation_width, dilation_height) = dilation;
for oh in 0..out_height {
// Construct a sub-slice view of the input row.
// This is done upfront so that rustc does not have to emit bounds checks
// in the hot loop below.
let ir = x
.row(oh * stride_height + kh * dilation_height)
.to_slice()
.unwrap();
// Ditto. Construct a sub-slice view of the output row, and explicitly specify
// the bounds upfront as 0..out_width so that rustc can make the assumption
// that all accesses are in-bounds in the below loop.
let mut or = output.row_mut(oh);
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
#[allow(clippy::needless_range_loop)]
for ow in 0..out_width {
let iw = (ow * stride_width) + (kw * dilation_width);
or[ow] += ir[iw] * k;
}
}
}
pub(crate) fn conv2d<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,
bias: Option<NdArrayTensor<E, 1>>,
options: ConvOptions<2>,
) -> NdArrayTensor<E, 4> {
let [dilatation_height, dilatation_width] = options.dilation;
let [dilation_height, dilation_width] = options.dilation;
let [padding_height, padding_width] = options.padding;
let [stride_height, stride_width] = options.stride;
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
@ -25,59 +63,107 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
kernel_height,
stride_height,
padding_height,
dilatation_height,
dilation_height,
in_height,
);
let out_width = calculate_conv_output_size(
kernel_width,
stride_width,
padding_width,
dilatation_width,
dilation_width,
in_width,
);
let x = apply_padding_4d(x, options.padding, 0i32.elem()).array;
let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));
// Convert inputs from dynamic indexes to static to improve perf.
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
let weights = weight.array.into_dimensionality::<ndarray::Ix4>().unwrap();
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width]));
run_par!(|| {
iter_par!(0, batch_size * out_channels).for_each(|k| unsafe {
let b = k / out_channels;
let oc = k % out_channels;
let g = k % options.groups;
iter_par!(output.axis_iter_mut(Axis(0)))
.enumerate()
.for_each(
#[inline(never)]
|(k, mut output)| {
let b = k / out_channels;
let oc = k % out_channels;
let g = k % options.groups;
let output = unsafe_shared_out.get();
for ic in (in_channels * g)..(in_channels * (g + 1)) {
let weight_ic = ic - (g * in_channels);
for ic in (in_channels * g)..(in_channels * (g + 1)) {
for kh in 0..kernel_height {
for kw in 0..kernel_width {
for oh in 0..out_height {
for ow in 0..out_width {
let ih = oh * stride_height + kh * dilatation_height;
let iw = ow * stride_width + kw * dilatation_width;
let x = x.slice(s![b, ic, .., ..]);
let k = weights.slice(s![oc, weight_ic, .., ..]);
let weight_ic = ic - (g * in_channels);
output[[b, oc, oh, ow]] +=
x[[b, ic, ih, iw]] * weight.array[[oc, weight_ic, kh, kw]];
for kh in 0..kernel_height {
for kw in 0..kernel_width {
let k = k[[kh, kw]];
// NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization
// in the case that the stride/dilation is 1.
#[allow(clippy::if_same_then_else)]
if (1, 1, 1, 1)
== (
stride_width,
stride_height,
dilation_width,
dilation_height,
)
{
conv2d_mad_inner(
output.view_mut(),
x.view(),
k,
(kh, kw),
(out_width, out_height),
(stride_width, stride_height),
(dilation_width, dilation_height),
);
} else {
conv2d_mad_inner(
output.view_mut(),
x.view(),
k,
(kh, kw),
(out_width, out_height),
(stride_width, stride_height),
(dilation_width, dilation_height),
);
}
}
}
}
}
}
if let Some(bias) = &bias {
for oh in 0..out_height {
for ow in 0..out_width {
output[[b, oc, oh, ow]] += bias.array[oc];
if let Some(bias) = &bias {
let bias = bias.array[oc];
for oh in 0..out_height {
// Get a mutable slice reference to the row we're looping over.
// We explicitly define the bounds to 0..out_width so that rustc can make
// the assumption that all accesses are in-bounds.
let mut or = output.row_mut(oh);
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
#[allow(clippy::needless_range_loop)]
for ow in 0..out_width {
or[ow] += bias;
}
}
}
}
}
});
},
);
});
NdArrayTensor::new(output.into_dyn().into_shared())
let output = output
.into_shape([batch_size, out_channels, out_height, out_width])
.unwrap()
.into_dyn()
.into_shared();
NdArrayTensor::new(output)
}
pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
@ -114,7 +200,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
run_par!(|| {
iter_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
let b = k / (out_channels * options.groups);
let oc = k % out_channels;
let g = k % options.groups;

View File

@ -1,5 +1,5 @@
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use crate::{iter_par, run_par, UnsafeSharedRef};
use crate::{iter_range_par, run_par, UnsafeSharedRef};
use burn_tensor::ElementConversion;
use burn_tensor::{ops::TensorOps, Shape};
use ndarray::s;
@ -58,7 +58,7 @@ fn general_matmul<E: FloatNdArrayElement>(
let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();
iter_par!(0, batch_size).for_each(|b| {
iter_range_par!(0, batch_size).for_each(|b| {
let lhs_slice = match batch_size_lhs == 1 {
true => lhs_array.slice(s!(0, .., ..)),
false => lhs_array.slice(s!(b, .., ..)),

View File

@ -1,5 +1,5 @@
use crate::{
element::FloatNdArrayElement, iter_par, ops::padding::apply_padding_4d, run_par,
element::FloatNdArrayElement, iter_range_par, ops::padding::apply_padding_4d, run_par,
sharing::UnsafeSharedRef, tensor::NdArrayTensor,
};
@ -33,7 +33,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
run_par!(|| {
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;
@ -96,7 +96,7 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
run_par!(|| {
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;
@ -159,7 +159,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
run_par!(|| {
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;

View File

@ -18,9 +18,25 @@ macro_rules! run_par {
}};
}
/// Macro for iterating over a range in parallel.
/// Macro for iterating in parallel.
#[macro_export(local_inner_macros)]
macro_rules! iter_par {
(
$iter:expr
) => {{
#[cfg(feature = "std")]
let output = $iter.into_par_iter();
#[cfg(not(feature = "std"))]
let output = $iter;
output
}};
}
/// Macro for iterating over a range in parallel.
#[macro_export(local_inner_macros)]
macro_rules! iter_range_par {
(
$start:expr, $end:expr
) => {{