mirror of https://github.com/tracel-ai/burn.git
Perf/ndarray: Optimize `conv2d` operation (#747)
This commit is contained in:
parent
a47d23c3dd
commit
06157d3cde
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
|
||||
element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
|
||||
tensor::NdArrayTensor,
|
||||
};
|
||||
use burn_tensor::ElementConversion;
|
||||
|
@ -19,7 +19,7 @@ pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
|
|||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
|
@ -61,7 +61,7 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
|
|||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
|
||||
element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
|
||||
tensor::NdArrayTensor,
|
||||
};
|
||||
|
||||
|
@ -27,7 +27,7 @@ pub(crate) fn avg_pool2d<E: FloatNdArrayElement>(
|
|||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
|
@ -92,7 +92,7 @@ pub(crate) fn avg_pool2d_backward<E: FloatNdArrayElement>(
|
|||
let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
|
|
|
@ -392,30 +392,36 @@ where
|
|||
arg(tensor, dim, CmpType::Min)
|
||||
}
|
||||
|
||||
pub fn clamp_min<const D: usize>(tensor: NdArrayTensor<E, D>, min: E) -> NdArrayTensor<E, D> {
|
||||
let array = tensor.array.mapv(|x| match x < min {
|
||||
pub fn clamp_min<const D: usize>(
|
||||
mut tensor: NdArrayTensor<E, D>,
|
||||
min: E,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
tensor.array.mapv_inplace(|x| match x < min {
|
||||
true => min,
|
||||
false => x,
|
||||
});
|
||||
|
||||
NdArrayTensor::new(array.into_shared())
|
||||
tensor
|
||||
}
|
||||
|
||||
pub fn clamp_max<const D: usize>(tensor: NdArrayTensor<E, D>, max: E) -> NdArrayTensor<E, D> {
|
||||
let array = tensor.array.mapv(|x| match x > max {
|
||||
pub fn clamp_max<const D: usize>(
|
||||
mut tensor: NdArrayTensor<E, D>,
|
||||
max: E,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
tensor.array.mapv_inplace(|x| match x > max {
|
||||
true => max,
|
||||
false => x,
|
||||
});
|
||||
|
||||
NdArrayTensor::new(array.into_shared())
|
||||
tensor
|
||||
}
|
||||
|
||||
pub fn clamp<const D: usize>(
|
||||
tensor: NdArrayTensor<E, D>,
|
||||
mut tensor: NdArrayTensor<E, D>,
|
||||
min: E,
|
||||
max: E,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
let array = tensor.array.mapv(|x| match x < min {
|
||||
tensor.array.mapv_inplace(|x| match x < min {
|
||||
true => min,
|
||||
false => match x > max {
|
||||
true => max,
|
||||
|
@ -423,7 +429,7 @@ where
|
|||
},
|
||||
});
|
||||
|
||||
NdArrayTensor::new(array.into_shared())
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,20 +2,58 @@ use burn_tensor::{
|
|||
ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions},
|
||||
ElementConversion,
|
||||
};
|
||||
use ndarray::{Array4, Dim};
|
||||
use ndarray::{s, Array3, Array4, ArrayView2, ArrayViewMut2, Axis, Dim};
|
||||
|
||||
use crate::{
|
||||
element::FloatNdArrayElement, iter_par, ops::padding::apply_padding_4d, run_par,
|
||||
sharing::UnsafeSharedRef, tensor::NdArrayTensor,
|
||||
element::FloatNdArrayElement, iter_par, iter_range_par, ops::padding::apply_padding_4d,
|
||||
run_par, sharing::UnsafeSharedRef, tensor::NdArrayTensor,
|
||||
};
|
||||
|
||||
#[inline(always)]
|
||||
fn conv2d_mad_inner<E: FloatNdArrayElement>(
|
||||
mut output: ArrayViewMut2<E>,
|
||||
x: ArrayView2<E>,
|
||||
k: E,
|
||||
k_xy: (usize, usize),
|
||||
out_xy: (usize, usize),
|
||||
stride: (usize, usize),
|
||||
dilation: (usize, usize),
|
||||
) {
|
||||
let (kh, kw) = k_xy;
|
||||
let (out_width, out_height) = out_xy;
|
||||
let (stride_width, stride_height) = stride;
|
||||
let (dilation_width, dilation_height) = dilation;
|
||||
|
||||
for oh in 0..out_height {
|
||||
// Construct a sub-slice view of the input row.
|
||||
// This is done upfront so that rustc does not have to emit bounds checks
|
||||
// in the hot loop below.
|
||||
let ir = x
|
||||
.row(oh * stride_height + kh * dilation_height)
|
||||
.to_slice()
|
||||
.unwrap();
|
||||
|
||||
// Ditto. Construct a sub-slice view of the output row, and explicitly specify
|
||||
// the bounds upfront as 0..out_width so that rustc can make the assumption
|
||||
// that all accesses are in-bounds in the below loop.
|
||||
let mut or = output.row_mut(oh);
|
||||
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for ow in 0..out_width {
|
||||
let iw = (ow * stride_width) + (kw * dilation_width);
|
||||
or[ow] += ir[iw] * k;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
||||
x: NdArrayTensor<E, 4>,
|
||||
weight: NdArrayTensor<E, 4>,
|
||||
bias: Option<NdArrayTensor<E, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
let [dilatation_height, dilatation_width] = options.dilation;
|
||||
let [dilation_height, dilation_width] = options.dilation;
|
||||
let [padding_height, padding_width] = options.padding;
|
||||
let [stride_height, stride_width] = options.stride;
|
||||
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
|
||||
|
@ -25,59 +63,107 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
|||
kernel_height,
|
||||
stride_height,
|
||||
padding_height,
|
||||
dilatation_height,
|
||||
dilation_height,
|
||||
in_height,
|
||||
);
|
||||
let out_width = calculate_conv_output_size(
|
||||
kernel_width,
|
||||
stride_width,
|
||||
padding_width,
|
||||
dilatation_width,
|
||||
dilation_width,
|
||||
in_width,
|
||||
);
|
||||
|
||||
let x = apply_padding_4d(x, options.padding, 0i32.elem()).array;
|
||||
|
||||
let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));
|
||||
// Convert inputs from dynamic indexes to static to improve perf.
|
||||
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||
let weights = weight.array.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width]));
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * out_channels).for_each(|k| unsafe {
|
||||
let b = k / out_channels;
|
||||
let oc = k % out_channels;
|
||||
let g = k % options.groups;
|
||||
iter_par!(output.axis_iter_mut(Axis(0)))
|
||||
.enumerate()
|
||||
.for_each(
|
||||
#[inline(never)]
|
||||
|(k, mut output)| {
|
||||
let b = k / out_channels;
|
||||
let oc = k % out_channels;
|
||||
let g = k % options.groups;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
for ic in (in_channels * g)..(in_channels * (g + 1)) {
|
||||
let weight_ic = ic - (g * in_channels);
|
||||
|
||||
for ic in (in_channels * g)..(in_channels * (g + 1)) {
|
||||
for kh in 0..kernel_height {
|
||||
for kw in 0..kernel_width {
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
let ih = oh * stride_height + kh * dilatation_height;
|
||||
let iw = ow * stride_width + kw * dilatation_width;
|
||||
let x = x.slice(s![b, ic, .., ..]);
|
||||
let k = weights.slice(s![oc, weight_ic, .., ..]);
|
||||
|
||||
let weight_ic = ic - (g * in_channels);
|
||||
output[[b, oc, oh, ow]] +=
|
||||
x[[b, ic, ih, iw]] * weight.array[[oc, weight_ic, kh, kw]];
|
||||
for kh in 0..kernel_height {
|
||||
for kw in 0..kernel_width {
|
||||
let k = k[[kh, kw]];
|
||||
|
||||
// NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization
|
||||
// in the case that the stride/dilation is 1.
|
||||
#[allow(clippy::if_same_then_else)]
|
||||
if (1, 1, 1, 1)
|
||||
== (
|
||||
stride_width,
|
||||
stride_height,
|
||||
dilation_width,
|
||||
dilation_height,
|
||||
)
|
||||
{
|
||||
conv2d_mad_inner(
|
||||
output.view_mut(),
|
||||
x.view(),
|
||||
k,
|
||||
(kh, kw),
|
||||
(out_width, out_height),
|
||||
(stride_width, stride_height),
|
||||
(dilation_width, dilation_height),
|
||||
);
|
||||
} else {
|
||||
conv2d_mad_inner(
|
||||
output.view_mut(),
|
||||
x.view(),
|
||||
k,
|
||||
(kh, kw),
|
||||
(out_width, out_height),
|
||||
(stride_width, stride_height),
|
||||
(dilation_width, dilation_height),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(bias) = &bias {
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
output[[b, oc, oh, ow]] += bias.array[oc];
|
||||
if let Some(bias) = &bias {
|
||||
let bias = bias.array[oc];
|
||||
|
||||
for oh in 0..out_height {
|
||||
// Get a mutable slice reference to the row we're looping over.
|
||||
// We explicitly define the bounds to 0..out_width so that rustc can make
|
||||
// the assumption that all accesses are in-bounds.
|
||||
let mut or = output.row_mut(oh);
|
||||
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for ow in 0..out_width {
|
||||
or[ow] += bias;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
NdArrayTensor::new(output.into_dyn().into_shared())
|
||||
let output = output
|
||||
.into_shape([batch_size, out_channels, out_height, out_width])
|
||||
.unwrap()
|
||||
.into_dyn()
|
||||
.into_shared();
|
||||
|
||||
NdArrayTensor::new(output)
|
||||
}
|
||||
|
||||
pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
||||
|
@ -114,7 +200,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
|||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
|
||||
iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
|
||||
let b = k / (out_channels * options.groups);
|
||||
let oc = k % out_channels;
|
||||
let g = k % options.groups;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
|
||||
use crate::{iter_par, run_par, UnsafeSharedRef};
|
||||
use crate::{iter_range_par, run_par, UnsafeSharedRef};
|
||||
use burn_tensor::ElementConversion;
|
||||
use burn_tensor::{ops::TensorOps, Shape};
|
||||
use ndarray::s;
|
||||
|
@ -58,7 +58,7 @@ fn general_matmul<E: FloatNdArrayElement>(
|
|||
let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
|
||||
let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();
|
||||
|
||||
iter_par!(0, batch_size).for_each(|b| {
|
||||
iter_range_par!(0, batch_size).for_each(|b| {
|
||||
let lhs_slice = match batch_size_lhs == 1 {
|
||||
true => lhs_array.slice(s!(0, .., ..)),
|
||||
false => lhs_array.slice(s!(b, .., ..)),
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
element::FloatNdArrayElement, iter_par, ops::padding::apply_padding_4d, run_par,
|
||||
element::FloatNdArrayElement, iter_range_par, ops::padding::apply_padding_4d, run_par,
|
||||
sharing::UnsafeSharedRef, tensor::NdArrayTensor,
|
||||
};
|
||||
|
||||
|
@ -33,7 +33,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
|
|||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
|
@ -96,7 +96,7 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
|
|||
let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
|
@ -159,7 +159,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
|
|||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
let b = k / channels;
|
||||
let c = k % channels;
|
||||
|
||||
|
|
|
@ -18,9 +18,25 @@ macro_rules! run_par {
|
|||
}};
|
||||
}
|
||||
|
||||
/// Macro for iterating over a range in parallel.
|
||||
/// Macro for iterating in parallel.
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! iter_par {
|
||||
(
|
||||
$iter:expr
|
||||
) => {{
|
||||
#[cfg(feature = "std")]
|
||||
let output = $iter.into_par_iter();
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
let output = $iter;
|
||||
|
||||
output
|
||||
}};
|
||||
}
|
||||
|
||||
/// Macro for iterating over a range in parallel.
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! iter_range_par {
|
||||
(
|
||||
$start:expr, $end:expr
|
||||
) => {{
|
||||
|
|
Loading…
Reference in New Issue