mirror of https://github.com/tracel-ai/burn.git
Perf/ndarray: Optimize `conv2d` operation (#747)
This commit is contained in:
parent
a47d23c3dd
commit
06157d3cde
|
@ -1,5 +1,5 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
|
element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
|
||||||
tensor::NdArrayTensor,
|
tensor::NdArrayTensor,
|
||||||
};
|
};
|
||||||
use burn_tensor::ElementConversion;
|
use burn_tensor::ElementConversion;
|
||||||
|
@ -19,7 +19,7 @@ pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
|
||||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||||
|
|
||||||
run_par!(|| {
|
run_par!(|| {
|
||||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||||
let b = k / channels;
|
let b = k / channels;
|
||||||
let c = k % channels;
|
let c = k % channels;
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
|
||||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);
|
let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);
|
||||||
|
|
||||||
run_par!(|| {
|
run_par!(|| {
|
||||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||||
let b = k / channels;
|
let b = k / channels;
|
||||||
let c = k % channels;
|
let c = k % channels;
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
element::FloatNdArrayElement, iter_par, run_par, sharing::UnsafeSharedRef,
|
element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
|
||||||
tensor::NdArrayTensor,
|
tensor::NdArrayTensor,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ pub(crate) fn avg_pool2d<E: FloatNdArrayElement>(
|
||||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||||
|
|
||||||
run_par!(|| {
|
run_par!(|| {
|
||||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||||
let b = k / channels;
|
let b = k / channels;
|
||||||
let c = k % channels;
|
let c = k % channels;
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ pub(crate) fn avg_pool2d_backward<E: FloatNdArrayElement>(
|
||||||
let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad);
|
let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad);
|
||||||
|
|
||||||
run_par!(|| {
|
run_par!(|| {
|
||||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||||
let b = k / channels;
|
let b = k / channels;
|
||||||
let c = k % channels;
|
let c = k % channels;
|
||||||
|
|
||||||
|
|
|
@ -392,30 +392,36 @@ where
|
||||||
arg(tensor, dim, CmpType::Min)
|
arg(tensor, dim, CmpType::Min)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn clamp_min<const D: usize>(tensor: NdArrayTensor<E, D>, min: E) -> NdArrayTensor<E, D> {
|
pub fn clamp_min<const D: usize>(
|
||||||
let array = tensor.array.mapv(|x| match x < min {
|
mut tensor: NdArrayTensor<E, D>,
|
||||||
|
min: E,
|
||||||
|
) -> NdArrayTensor<E, D> {
|
||||||
|
tensor.array.mapv_inplace(|x| match x < min {
|
||||||
true => min,
|
true => min,
|
||||||
false => x,
|
false => x,
|
||||||
});
|
});
|
||||||
|
|
||||||
NdArrayTensor::new(array.into_shared())
|
tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn clamp_max<const D: usize>(tensor: NdArrayTensor<E, D>, max: E) -> NdArrayTensor<E, D> {
|
pub fn clamp_max<const D: usize>(
|
||||||
let array = tensor.array.mapv(|x| match x > max {
|
mut tensor: NdArrayTensor<E, D>,
|
||||||
|
max: E,
|
||||||
|
) -> NdArrayTensor<E, D> {
|
||||||
|
tensor.array.mapv_inplace(|x| match x > max {
|
||||||
true => max,
|
true => max,
|
||||||
false => x,
|
false => x,
|
||||||
});
|
});
|
||||||
|
|
||||||
NdArrayTensor::new(array.into_shared())
|
tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn clamp<const D: usize>(
|
pub fn clamp<const D: usize>(
|
||||||
tensor: NdArrayTensor<E, D>,
|
mut tensor: NdArrayTensor<E, D>,
|
||||||
min: E,
|
min: E,
|
||||||
max: E,
|
max: E,
|
||||||
) -> NdArrayTensor<E, D> {
|
) -> NdArrayTensor<E, D> {
|
||||||
let array = tensor.array.mapv(|x| match x < min {
|
tensor.array.mapv_inplace(|x| match x < min {
|
||||||
true => min,
|
true => min,
|
||||||
false => match x > max {
|
false => match x > max {
|
||||||
true => max,
|
true => max,
|
||||||
|
@ -423,7 +429,7 @@ where
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
NdArrayTensor::new(array.into_shared())
|
tensor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,20 +2,58 @@ use burn_tensor::{
|
||||||
ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions},
|
ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions},
|
||||||
ElementConversion,
|
ElementConversion,
|
||||||
};
|
};
|
||||||
use ndarray::{Array4, Dim};
|
use ndarray::{s, Array3, Array4, ArrayView2, ArrayViewMut2, Axis, Dim};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
element::FloatNdArrayElement, iter_par, ops::padding::apply_padding_4d, run_par,
|
element::FloatNdArrayElement, iter_par, iter_range_par, ops::padding::apply_padding_4d,
|
||||||
sharing::UnsafeSharedRef, tensor::NdArrayTensor,
|
run_par, sharing::UnsafeSharedRef, tensor::NdArrayTensor,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn conv2d_mad_inner<E: FloatNdArrayElement>(
|
||||||
|
mut output: ArrayViewMut2<E>,
|
||||||
|
x: ArrayView2<E>,
|
||||||
|
k: E,
|
||||||
|
k_xy: (usize, usize),
|
||||||
|
out_xy: (usize, usize),
|
||||||
|
stride: (usize, usize),
|
||||||
|
dilation: (usize, usize),
|
||||||
|
) {
|
||||||
|
let (kh, kw) = k_xy;
|
||||||
|
let (out_width, out_height) = out_xy;
|
||||||
|
let (stride_width, stride_height) = stride;
|
||||||
|
let (dilation_width, dilation_height) = dilation;
|
||||||
|
|
||||||
|
for oh in 0..out_height {
|
||||||
|
// Construct a sub-slice view of the input row.
|
||||||
|
// This is done upfront so that rustc does not have to emit bounds checks
|
||||||
|
// in the hot loop below.
|
||||||
|
let ir = x
|
||||||
|
.row(oh * stride_height + kh * dilation_height)
|
||||||
|
.to_slice()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Ditto. Construct a sub-slice view of the output row, and explicitly specify
|
||||||
|
// the bounds upfront as 0..out_width so that rustc can make the assumption
|
||||||
|
// that all accesses are in-bounds in the below loop.
|
||||||
|
let mut or = output.row_mut(oh);
|
||||||
|
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
|
||||||
|
|
||||||
|
#[allow(clippy::needless_range_loop)]
|
||||||
|
for ow in 0..out_width {
|
||||||
|
let iw = (ow * stride_width) + (kw * dilation_width);
|
||||||
|
or[ow] += ir[iw] * k;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
||||||
x: NdArrayTensor<E, 4>,
|
x: NdArrayTensor<E, 4>,
|
||||||
weight: NdArrayTensor<E, 4>,
|
weight: NdArrayTensor<E, 4>,
|
||||||
bias: Option<NdArrayTensor<E, 1>>,
|
bias: Option<NdArrayTensor<E, 1>>,
|
||||||
options: ConvOptions<2>,
|
options: ConvOptions<2>,
|
||||||
) -> NdArrayTensor<E, 4> {
|
) -> NdArrayTensor<E, 4> {
|
||||||
let [dilatation_height, dilatation_width] = options.dilation;
|
let [dilation_height, dilation_width] = options.dilation;
|
||||||
let [padding_height, padding_width] = options.padding;
|
let [padding_height, padding_width] = options.padding;
|
||||||
let [stride_height, stride_width] = options.stride;
|
let [stride_height, stride_width] = options.stride;
|
||||||
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
|
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
|
||||||
|
@ -25,59 +63,107 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
|
||||||
kernel_height,
|
kernel_height,
|
||||||
stride_height,
|
stride_height,
|
||||||
padding_height,
|
padding_height,
|
||||||
dilatation_height,
|
dilation_height,
|
||||||
in_height,
|
in_height,
|
||||||
);
|
);
|
||||||
let out_width = calculate_conv_output_size(
|
let out_width = calculate_conv_output_size(
|
||||||
kernel_width,
|
kernel_width,
|
||||||
stride_width,
|
stride_width,
|
||||||
padding_width,
|
padding_width,
|
||||||
dilatation_width,
|
dilation_width,
|
||||||
in_width,
|
in_width,
|
||||||
);
|
);
|
||||||
|
|
||||||
let x = apply_padding_4d(x, options.padding, 0i32.elem()).array;
|
let x = apply_padding_4d(x, options.padding, 0i32.elem()).array;
|
||||||
|
|
||||||
let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));
|
// Convert inputs from dynamic indexes to static to improve perf.
|
||||||
|
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||||
|
let weights = weight.array.into_dimensionality::<ndarray::Ix4>().unwrap();
|
||||||
|
|
||||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width]));
|
||||||
|
|
||||||
run_par!(|| {
|
run_par!(|| {
|
||||||
iter_par!(0, batch_size * out_channels).for_each(|k| unsafe {
|
iter_par!(output.axis_iter_mut(Axis(0)))
|
||||||
|
.enumerate()
|
||||||
|
.for_each(
|
||||||
|
#[inline(never)]
|
||||||
|
|(k, mut output)| {
|
||||||
let b = k / out_channels;
|
let b = k / out_channels;
|
||||||
let oc = k % out_channels;
|
let oc = k % out_channels;
|
||||||
let g = k % options.groups;
|
let g = k % options.groups;
|
||||||
|
|
||||||
let output = unsafe_shared_out.get();
|
|
||||||
|
|
||||||
for ic in (in_channels * g)..(in_channels * (g + 1)) {
|
for ic in (in_channels * g)..(in_channels * (g + 1)) {
|
||||||
|
let weight_ic = ic - (g * in_channels);
|
||||||
|
|
||||||
|
let x = x.slice(s![b, ic, .., ..]);
|
||||||
|
let k = weights.slice(s![oc, weight_ic, .., ..]);
|
||||||
|
|
||||||
for kh in 0..kernel_height {
|
for kh in 0..kernel_height {
|
||||||
for kw in 0..kernel_width {
|
for kw in 0..kernel_width {
|
||||||
for oh in 0..out_height {
|
let k = k[[kh, kw]];
|
||||||
for ow in 0..out_width {
|
|
||||||
let ih = oh * stride_height + kh * dilatation_height;
|
|
||||||
let iw = ow * stride_width + kw * dilatation_width;
|
|
||||||
|
|
||||||
let weight_ic = ic - (g * in_channels);
|
// NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization
|
||||||
output[[b, oc, oh, ow]] +=
|
// in the case that the stride/dilation is 1.
|
||||||
x[[b, ic, ih, iw]] * weight.array[[oc, weight_ic, kh, kw]];
|
#[allow(clippy::if_same_then_else)]
|
||||||
}
|
if (1, 1, 1, 1)
|
||||||
|
== (
|
||||||
|
stride_width,
|
||||||
|
stride_height,
|
||||||
|
dilation_width,
|
||||||
|
dilation_height,
|
||||||
|
)
|
||||||
|
{
|
||||||
|
conv2d_mad_inner(
|
||||||
|
output.view_mut(),
|
||||||
|
x.view(),
|
||||||
|
k,
|
||||||
|
(kh, kw),
|
||||||
|
(out_width, out_height),
|
||||||
|
(stride_width, stride_height),
|
||||||
|
(dilation_width, dilation_height),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
conv2d_mad_inner(
|
||||||
|
output.view_mut(),
|
||||||
|
x.view(),
|
||||||
|
k,
|
||||||
|
(kh, kw),
|
||||||
|
(out_width, out_height),
|
||||||
|
(stride_width, stride_height),
|
||||||
|
(dilation_width, dilation_height),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(bias) = &bias {
|
if let Some(bias) = &bias {
|
||||||
|
let bias = bias.array[oc];
|
||||||
|
|
||||||
for oh in 0..out_height {
|
for oh in 0..out_height {
|
||||||
|
// Get a mutable slice reference to the row we're looping over.
|
||||||
|
// We explicitly define the bounds to 0..out_width so that rustc can make
|
||||||
|
// the assumption that all accesses are in-bounds.
|
||||||
|
let mut or = output.row_mut(oh);
|
||||||
|
let or = &mut or.as_slice_mut().unwrap()[0..out_width];
|
||||||
|
|
||||||
|
#[allow(clippy::needless_range_loop)]
|
||||||
for ow in 0..out_width {
|
for ow in 0..out_width {
|
||||||
output[[b, oc, oh, ow]] += bias.array[oc];
|
or[ow] += bias;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
},
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
NdArrayTensor::new(output.into_dyn().into_shared())
|
let output = output
|
||||||
|
.into_shape([batch_size, out_channels, out_height, out_width])
|
||||||
|
.unwrap()
|
||||||
|
.into_dyn()
|
||||||
|
.into_shared();
|
||||||
|
|
||||||
|
NdArrayTensor::new(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
||||||
|
@ -114,7 +200,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
|
||||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||||
|
|
||||||
run_par!(|| {
|
run_par!(|| {
|
||||||
iter_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
|
iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
|
||||||
let b = k / (out_channels * options.groups);
|
let b = k / (out_channels * options.groups);
|
||||||
let oc = k % out_channels;
|
let oc = k % out_channels;
|
||||||
let g = k % options.groups;
|
let g = k % options.groups;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
|
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
|
||||||
use crate::{iter_par, run_par, UnsafeSharedRef};
|
use crate::{iter_range_par, run_par, UnsafeSharedRef};
|
||||||
use burn_tensor::ElementConversion;
|
use burn_tensor::ElementConversion;
|
||||||
use burn_tensor::{ops::TensorOps, Shape};
|
use burn_tensor::{ops::TensorOps, Shape};
|
||||||
use ndarray::s;
|
use ndarray::s;
|
||||||
|
@ -58,7 +58,7 @@ fn general_matmul<E: FloatNdArrayElement>(
|
||||||
let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
|
let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
|
||||||
let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();
|
let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();
|
||||||
|
|
||||||
iter_par!(0, batch_size).for_each(|b| {
|
iter_range_par!(0, batch_size).for_each(|b| {
|
||||||
let lhs_slice = match batch_size_lhs == 1 {
|
let lhs_slice = match batch_size_lhs == 1 {
|
||||||
true => lhs_array.slice(s!(0, .., ..)),
|
true => lhs_array.slice(s!(0, .., ..)),
|
||||||
false => lhs_array.slice(s!(b, .., ..)),
|
false => lhs_array.slice(s!(b, .., ..)),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
element::FloatNdArrayElement, iter_par, ops::padding::apply_padding_4d, run_par,
|
element::FloatNdArrayElement, iter_range_par, ops::padding::apply_padding_4d, run_par,
|
||||||
sharing::UnsafeSharedRef, tensor::NdArrayTensor,
|
sharing::UnsafeSharedRef, tensor::NdArrayTensor,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
|
||||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||||
|
|
||||||
run_par!(|| {
|
run_par!(|| {
|
||||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||||
let b = k / channels;
|
let b = k / channels;
|
||||||
let c = k % channels;
|
let c = k % channels;
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
|
||||||
let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
|
let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
|
||||||
|
|
||||||
run_par!(|| {
|
run_par!(|| {
|
||||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||||
let b = k / channels;
|
let b = k / channels;
|
||||||
let c = k % channels;
|
let c = k % channels;
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
|
||||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||||
|
|
||||||
run_par!(|| {
|
run_par!(|| {
|
||||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||||
let b = k / channels;
|
let b = k / channels;
|
||||||
let c = k % channels;
|
let c = k % channels;
|
||||||
|
|
||||||
|
|
|
@ -18,9 +18,25 @@ macro_rules! run_par {
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Macro for iterating over a range in parallel.
|
/// Macro for iterating in parallel.
|
||||||
#[macro_export(local_inner_macros)]
|
#[macro_export(local_inner_macros)]
|
||||||
macro_rules! iter_par {
|
macro_rules! iter_par {
|
||||||
|
(
|
||||||
|
$iter:expr
|
||||||
|
) => {{
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
let output = $iter.into_par_iter();
|
||||||
|
|
||||||
|
#[cfg(not(feature = "std"))]
|
||||||
|
let output = $iter;
|
||||||
|
|
||||||
|
output
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Macro for iterating over a range in parallel.
|
||||||
|
#[macro_export(local_inner_macros)]
|
||||||
|
macro_rules! iter_range_par {
|
||||||
(
|
(
|
||||||
$start:expr, $end:expr
|
$start:expr, $end:expr
|
||||||
) => {{
|
) => {{
|
||||||
|
|
Loading…
Reference in New Issue