Feat/group conv (#306)

Nathaniel Simard 2023-04-22 15:00:41 -04:00 committed by GitHub
parent 78ac09fb7a
commit c5e31b272f
18 changed files with 1200 additions and 870 deletions
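In short: the separate `stride`, `padding`, and `dilation` arguments on the convolution ops are collapsed into a single `ConvOptions` struct (plus `ConvTransposeOptions` for the transposed variants), which also carries the new `groups` field, and the change is threaded through the autodiff decorator, the ndarray and tch backends, the `Conv1d`/`Conv2d` module configs, and the gradient tests. A minimal sketch of the reworked call site follows; it is not part of the diff, and the tensor setup and helper name are assumptions:

```rust
use burn_tensor::{backend::Backend, module::conv2d, ops::ConvOptions, Tensor};

/// Hypothetical helper showing the new options-based API.
fn grouped_conv2d_example<B: Backend>(
    x: Tensor<B, 4>,      // [batch, channels_in, height, width]
    weight: Tensor<B, 4>, // [channels_out, channels_in / groups, k1, k2]
    bias: Tensor<B, 1>,   // [channels_out]
) -> Tensor<B, 4> {
    conv2d(
        x,
        weight,
        Some(bias),
        // stride, padding, dilation, groups
        ConvOptions::new([1, 1], [1, 1], [1, 1], 2),
    )
}
```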

View File

@ -50,9 +50,7 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
x: ADTensor<B, 4>,
weight: ADTensor<B, 4>,
bias: Option<ADTensor<B, 1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
options: ConvOptions<2>,
) -> ADTensor<B, 4> {
#[derive(Debug)]
struct Conv2DWithBias;
@ -64,18 +62,15 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::TensorPrimitive<4>,
B::TensorPrimitive<4>,
B::TensorPrimitive<1>,
[usize; 2],
[usize; 2],
[usize; 2],
ConvOptions<2>,
);
fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients) {
let [node_x, node_weight, node_bias] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
let (x, weight, bias, stride, padding, dilation) = ops.state;
let backward =
B::conv2d_backward(x, weight, Some(bias), stride, padding, dilation, grad);
let (x, weight, bias, options) = ops.state;
let backward = B::conv2d_backward(x, weight, Some(bias), grad, options);
if let Some(node) = node_x {
grads.register::<B, 4>(node, backward.x_grad)
@ -90,20 +85,14 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
impl<B: Backend> Backward<B, 4, 2> for Conv2DNoBias {
type State = (
B::TensorPrimitive<4>,
B::TensorPrimitive<4>,
[usize; 2],
[usize; 2],
[usize; 2],
);
type State = (B::TensorPrimitive<4>, B::TensorPrimitive<4>, ConvOptions<2>);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let [node_x, node_weight] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
let (x, weight, stride, padding, dilation) = ops.state;
let backward = B::conv2d_backward(x, weight, None, stride, padding, dilation, grad);
let (x, weight, options) = ops.state;
let backward = B::conv2d_backward(x, weight, None, grad, options);
if let Some(node) = node_x {
grads.register::<B, 4>(node, backward.x_grad)
@ -128,26 +117,15 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
stride,
padding,
dilation,
),
B::conv2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
stride,
padding,
dilation,
options.clone(),
),
B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
stride,
padding,
dilation,
options,
)),
}
}
@ -160,27 +138,13 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
(
x.primitive.clone(),
weight.primitive.clone(),
stride,
padding,
dilation,
options.clone(),
),
B::conv2d(
x.primitive,
weight.primitive,
None,
stride,
padding,
dilation,
B::conv2d(x.primitive, weight.primitive, None, options),
),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
x.primitive,
weight.primitive,
None,
stride,
padding,
dilation,
)),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
}
}
}
}
@ -190,33 +154,16 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
_x: ADTensor<B, 4>,
_weight: ADTensor<B, 4>,
_bias: Option<ADTensor<B, 1>>,
_stride: [usize; 2],
_padding: [usize; 2],
_padding_out: [usize; 2],
_dilation: [usize; 2],
_options: ConvTransposeOptions<2>,
) -> ADTensor<B, 4> {
todo!("Transposed 2D convolution doesn't yet support backward.");
}
fn conv_transpose1d(
_x: ADTensor<B, 3>,
_weight: ADTensor<B, 3>,
_bias: Option<ADTensor<B, 1>>,
_stride: usize,
_padding: usize,
_padding_out: usize,
_dilation: usize,
) -> ADTensor<B, 3> {
todo!("Transposed 1D convolution doesn't yet support backward.");
}
fn conv1d(
x: ADTensor<B, 3>,
weight: ADTensor<B, 3>,
bias: Option<ADTensor<B, 1>>,
stride: usize,
padding: usize,
dilation: usize,
options: ConvOptions<1>,
) -> ADTensor<B, 3> {
#[derive(Debug)]
struct Conv1DWithBias;
@ -228,18 +175,15 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::TensorPrimitive<3>,
B::TensorPrimitive<3>,
B::TensorPrimitive<1>,
usize,
usize,
usize,
ConvOptions<1>,
);
fn backward(self, ops: Ops<Self::State, 3>, grads: &mut Gradients) {
let [node_x, node_weight, node_bias] = ops.parents;
let grad = grads.consume::<B, 3>(&ops.node);
let (x, weight, bias, stride, padding, dilation) = ops.state;
let backward =
B::conv1d_backward(x, weight, Some(bias), stride, padding, dilation, grad);
let (x, weight, bias, options) = ops.state;
let backward = B::conv1d_backward(x, weight, Some(bias), grad, options);
if let Some(node) = node_x {
grads.register::<B, 3>(node, backward.x_grad)
@ -254,20 +198,14 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
impl<B: Backend> Backward<B, 3, 2> for Conv1DNoBias {
type State = (
B::TensorPrimitive<3>,
B::TensorPrimitive<3>,
usize,
usize,
usize,
);
type State = (B::TensorPrimitive<3>, B::TensorPrimitive<3>, ConvOptions<1>);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let [node_x, node_weight] = ops.parents;
let grad = grads.consume::<B, 3>(&ops.node);
let (x, weight, stride, padding, dilation) = ops.state;
let backward = B::conv1d_backward(x, weight, None, stride, padding, dilation, grad);
let (x, weight, options) = ops.state;
let backward = B::conv1d_backward(x, weight, None, grad, options);
if let Some(node) = node_x {
grads.register::<B, 3>(node, backward.x_grad)
@ -291,26 +229,15 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
stride,
padding,
dilation,
),
B::conv1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
stride,
padding,
dilation,
options.clone(),
),
B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
stride,
padding,
dilation,
options,
)),
}
}
@ -323,31 +250,26 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
(
x.primitive.clone(),
weight.primitive.clone(),
stride,
padding,
dilation,
options.clone(),
),
B::conv1d(
x.primitive,
weight.primitive,
None,
stride,
padding,
dilation,
B::conv1d(x.primitive, weight.primitive, None, options),
),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
x.primitive,
weight.primitive,
None,
stride,
padding,
dilation,
)),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
}
}
}
}
}
fn conv_transpose1d(
_x: ADTensor<B, 3>,
_weight: ADTensor<B, 3>,
_bias: Option<ADTensor<B, 1>>,
_options: ConvTransposeOptions<1>,
) -> ADTensor<B, 3> {
todo!("Transposed 1D convolution doesn't yet support backward.");
}
fn max_pool2d(
x: ADTensor<B, 4>,

View File

@ -1,39 +1,31 @@
#[burn_tensor_testgen::testgen(ad_conv1d)]
mod tests {
use super::*;
use burn_tensor::{module::conv1d, Data};
use burn_tensor::{module::conv1d, ops::ConvOptions, Data, Shape};
#[test]
fn test_conv1d_basic() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 3,
channels_out: 3,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
length: 6,
groups: 1,
length: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
[
[6., 9., 9., 9., 9., 6.],
[6., 9., 9., 9., 9., 6.],
[6., 9., 9., 9., 9., 6.],
],
[
[6., 9., 9., 9., 9., 6.],
[6., 9., 9., 9., 9., 6.],
[6., 9., 9., 9., 9., 6.],
],
[[14., 24., 24., 18.], [26., 42., 42., 30.]],
[[14., 24., 24., 18.], [26., 42., 42., 30.]],
]),
weight: TestTensor::from_floats([
[[10., 12., 10.], [10., 12., 10.], [10., 12., 10.]],
[[10., 12., 10.], [10., 12., 10.], [10., 12., 10.]],
[[10., 12., 10.], [10., 12., 10.], [10., 12., 10.]],
[[30., 44., 36.], [54., 76., 60.]],
[[30., 44., 36.], [54., 76., 60.]],
]),
bias: TestTensor::from_floats([12., 12., 12.]),
bias: TestTensor::from_floats([8., 8.]),
};
test.assert_grads(grads);
}
@ -48,19 +40,20 @@ mod tests {
padding: 1,
stride: 1,
dilation: 1,
length: 6,
groups: 1,
length: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
[[6., 9., 9., 9., 9., 6.], [6., 9., 9., 9., 9., 6.]],
[[6., 9., 9., 9., 9., 6.], [6., 9., 9., 9., 9., 6.]],
[[39., 63., 63., 45.], [57., 90., 90., 63.]],
[[39., 63., 63., 45.], [57., 90., 90., 63.]],
]),
weight: TestTensor::from_floats([
[[10., 12., 10.], [10., 12., 10.]],
[[10., 12., 10.], [10., 12., 10.]],
[[10., 12., 10.], [10., 12., 10.]],
[[30., 44., 36.], [54., 76., 60.]],
[[30., 44., 36.], [54., 76., 60.]],
[[30., 44., 36.], [54., 76., 60.]],
]),
bias: TestTensor::from_floats([12., 12., 12.]),
bias: TestTensor::from_floats([8., 8., 8.]),
};
test.assert_grads(grads);
}
@ -75,18 +68,19 @@ mod tests {
padding: 2,
stride: 1,
dilation: 1,
length: 6,
groups: 1,
length: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
[[6., 6., 6., 6., 6., 6.], [6., 6., 6., 6., 6., 6.]],
[[6., 6., 6., 6., 6., 6.], [6., 6., 6., 6., 6., 6.]],
[[24., 24., 24., 24.], [42., 42., 42., 42.]],
[[24., 24., 24., 24.], [42., 42., 42., 42.]],
]),
weight: TestTensor::from_floats([
[[12., 12., 12.], [12., 12., 12.]],
[[12., 12., 12.], [12., 12., 12.]],
[[44., 44., 44.], [76., 76., 76.]],
[[44., 44., 44.], [76., 76., 76.]],
]),
bias: TestTensor::from_floats([16., 16.]),
bias: TestTensor::from_floats([12., 12.]),
};
test.assert_grads(grads);
}
@ -101,16 +95,17 @@ mod tests {
padding: 1,
stride: 2,
dilation: 1,
groups: 1,
length: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
[[2., 4., 2., 2.], [2., 4., 2., 2.]],
[[2., 4., 2., 2.], [2., 4., 2., 2.]],
[[8., 16., 8., 10.], [14., 28., 14., 16.]],
[[8., 16., 8., 10.], [14., 28., 14., 16.]],
]),
weight: TestTensor::from_floats([
[[2., 4., 4.], [2., 4., 4.]],
[[2., 4., 4.], [2., 4., 4.]],
[[10., 20., 24.], [18., 36., 40.]],
[[10., 20., 24.], [18., 36., 40.]],
]),
bias: TestTensor::from_floats([4., 4.]),
};
@ -127,22 +122,47 @@ mod tests {
padding: 1,
stride: 1,
dilation: 2,
groups: 1,
length: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
[[2., 2., 2., 2.], [2., 2., 2., 2.]],
[[2., 2., 2., 2.], [2., 2., 2., 2.]],
[[6., 8., 8., 10.], [12., 14., 14., 16.]],
[[6., 8., 8., 10.], [12., 14., 14., 16.]],
]),
weight: TestTensor::from_floats([
[[2., 4., 2.], [2., 4., 2.]],
[[2., 4., 2.], [2., 4., 2.]],
[[8., 22., 14.], [16., 38., 22.]],
[[8., 22., 14.], [16., 38., 22.]],
]),
bias: TestTensor::from_floats([4., 4.]),
};
test.assert_grads(grads);
}
#[test]
fn test_conv1d_groups() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
groups: 2,
length: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
[[1., 3., 3., 3.], [7., 12., 12., 9.]],
[[1., 3., 3., 3.], [7., 12., 12., 9.]],
]),
weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]]),
bias: TestTensor::from_floats([8., 8.]),
};
test.assert_grads(grads);
}
struct Conv1dTestCase {
batch_size: usize,
channels_in: usize,
@ -151,6 +171,7 @@ mod tests {
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
length: usize,
}
@ -162,19 +183,38 @@ mod tests {
impl Conv1dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let weight =
TestADTensor::ones([self.channels_out, self.channels_in, self.kernel_size])
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size,
]);
let weight = TestADTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements())
.reshape(shape_weight)
.into_data()
.convert(),
)
.require_grad();
let bias = TestADTensor::ones([self.channels_out]).require_grad();
let x =
TestADTensor::ones([self.batch_size, self.channels_in, self.length]).require_grad();
let bias = TestADTensor::from_data(
TestTensorInt::arange(0..self.channels_out)
.into_data()
.convert(),
)
.require_grad();
let x = TestADTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
)
.require_grad();
let output = conv1d(
x.clone(),
weight.clone(),
Some(bias.clone()),
self.stride,
self.padding,
self.dilation,
ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),
);
let grads = output.backward();
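Note on the rewritten test helper above: inputs are now built from `arange` rather than all ones, presumably so that grouped kernels produce distinct per-channel gradients, and the weight is shaped `[channels_out, channels_in / groups, kernel_size]`. A standalone sketch of that shape rule (not part of the diff; values chosen to mirror `test_conv1d_groups`):

```rust
/// Grouped 1D convolution weight shape: each output channel only sees
/// `channels_in / groups` input channels, so the middle dimension shrinks.
fn conv1d_weight_shape(
    channels_out: usize,
    channels_in: usize,
    groups: usize,
    kernel_size: usize,
) -> [usize; 3] {
    assert!(channels_in % groups == 0 && channels_out % groups == 0);
    [channels_out, channels_in / groups, kernel_size]
}

fn main() {
    // groups == channels => depthwise: [2, 1, 3], as in `test_conv1d_groups`.
    assert_eq!(conv1d_weight_shape(2, 2, 2, 3), [2, 1, 3]);
    // groups == 1 recovers the dense shape [2, 2, 3].
    assert_eq!(conv1d_weight_shape(2, 2, 1, 3), [2, 2, 3]);
}
```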

View File

@ -1,14 +1,14 @@
#[burn_tensor_testgen::testgen(ad_conv2d)]
mod tests {
use super::*;
use burn_tensor::{module::conv2d, Data};
use burn_tensor::{module::conv2d, ops::ConvOptions, Data, Shape};
#[test]
fn test_conv2d_basic() {
let test = Conv2dTestCase {
batch_size: 2,
channels_in: 3,
channels_out: 3,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
@ -17,82 +17,52 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 6,
groups: 1,
height: 4,
width: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
[
[
[12., 18., 18., 18., 18., 12.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[12., 18., 18., 18., 18., 12.],
[88., 138., 138., 96.],
[150., 234., 234., 162.],
[150., 234., 234., 162.],
[112., 174., 174., 120.],
],
[
[12., 18., 18., 18., 18., 12.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[12., 18., 18., 18., 18., 12.],
],
[
[12., 18., 18., 18., 18., 12.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[12., 18., 18., 18., 18., 12.],
[160., 246., 246., 168.],
[258., 396., 396., 270.],
[258., 396., 396., 270.],
[184., 282., 282., 192.],
],
],
[
[
[12., 18., 18., 18., 18., 12.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[12., 18., 18., 18., 18., 12.],
[88., 138., 138., 96.],
[150., 234., 234., 162.],
[150., 234., 234., 162.],
[112., 174., 174., 120.],
],
[
[12., 18., 18., 18., 18., 12.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[12., 18., 18., 18., 18., 12.],
],
[
[12., 18., 18., 18., 18., 12.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[12., 18., 18., 18., 18., 12.],
[160., 246., 246., 168.],
[258., 396., 396., 270.],
[258., 396., 396., 270.],
[184., 282., 282., 192.],
],
],
]),
weight: TestTensor::from_floats([
[
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
[
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
],
[
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
]),
bias: TestTensor::from_floats([72., 72., 72.]),
bias: TestTensor::from_floats([32., 32.]),
};
test.assert_grads(grads);
}
@ -111,63 +81,56 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 6,
groups: 1,
height: 4,
width: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
[
[
[12., 18., 18., 18., 18., 12.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[12., 18., 18., 18., 18., 12.],
[240., 369., 369., 252.],
[387., 594., 594., 405.],
[387., 594., 594., 405.],
[276., 423., 423., 288.],
],
[
[12., 18., 18., 18., 18., 12.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[12., 18., 18., 18., 18., 12.],
[348., 531., 531., 360.],
[549., 837., 837., 567.],
[549., 837., 837., 567.],
[384., 585., 585., 396.],
],
],
[
[
[12., 18., 18., 18., 18., 12.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[12., 18., 18., 18., 18., 12.],
[240., 369., 369., 252.],
[387., 594., 594., 405.],
[387., 594., 594., 405.],
[276., 423., 423., 288.],
],
[
[12., 18., 18., 18., 18., 12.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[18., 27., 27., 27., 27., 18.],
[12., 18., 18., 18., 18., 12.],
[348., 531., 531., 360.],
[549., 837., 837., 567.],
[549., 837., 837., 567.],
[384., 585., 585., 396.],
],
],
]),
weight: TestTensor::from_floats([
[
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
[
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
[
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[50., 60., 50.], [60., 72., 60.], [50., 60., 50.]],
[[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],
[[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],
],
]),
bias: TestTensor::from_floats([72., 72., 72.]),
bias: TestTensor::from_floats([32., 32., 32.]),
};
test.assert_grads(grads);
}
@ -175,7 +138,7 @@ mod tests {
#[test]
fn test_conv2d_different_kernel_size() {
let test = Conv2dTestCase {
batch_size: 2,
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
@ -186,75 +149,52 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 6,
groups: 1,
height: 4,
width: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
x: TestTensor::from_floats([[
[
[
[8., 12., 16., 16., 12., 8.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[8., 12., 16., 16., 12., 8.],
[116., 180., 192., 132.],
[198., 306., 324., 222.],
[198., 306., 324., 222.],
[148., 228., 240., 164.],
],
[
[8., 12., 16., 16., 12., 8.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[8., 12., 16., 16., 12., 8.],
[212., 324., 336., 228.],
[342., 522., 540., 366.],
[342., 522., 540., 366.],
[244., 372., 384., 260.],
],
],
[
[
[8., 12., 16., 16., 12., 8.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[8., 12., 16., 16., 12., 8.],
],
[
[8., 12., 16., 16., 12., 8.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[12., 18., 24., 24., 18., 12.],
[8., 12., 16., 16., 12., 8.],
],
],
]),
]]),
weight: TestTensor::from_floats([
[
[
[40., 50., 50., 40.],
[48., 60., 60., 48.],
[40., 50., 50., 40.],
[27., 45., 54., 39.],
[52., 84., 96., 68.],
[51., 81., 90., 63.],
],
[
[40., 50., 50., 40.],
[48., 60., 60., 48.],
[40., 50., 50., 40.],
[123., 189., 198., 135.],
[180., 276., 288., 196.],
[147., 225., 234., 159.],
],
],
[
[
[40., 50., 50., 40.],
[48., 60., 60., 48.],
[40., 50., 50., 40.],
[27., 45., 54., 39.],
[52., 84., 96., 68.],
[51., 81., 90., 63.],
],
[
[40., 50., 50., 40.],
[48., 60., 60., 48.],
[40., 50., 50., 40.],
[123., 189., 198., 135.],
[180., 276., 288., 196.],
[147., 225., 234., 159.],
],
],
]),
bias: TestTensor::from_floats([60., 60.]),
bias: TestTensor::from_floats([12., 12.]),
};
test.assert_grads(grads);
}
@ -262,7 +202,7 @@ mod tests {
#[test]
fn test_conv2d_different_padding() {
let test = Conv2dTestCase {
batch_size: 2,
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
@ -273,59 +213,36 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
width: 6,
groups: 1,
height: 4,
width: 4,
};
let grads = Grads {
x: TestTensor::from_floats([
x: TestTensor::from_floats([[
[
[
[12., 12., 12., 12., 12., 12.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[12., 12., 12., 12., 12., 12.],
[138., 138., 138., 138.],
[234., 234., 234., 234.],
[234., 234., 234., 234.],
[174., 174., 174., 174.],
],
[
[12., 12., 12., 12., 12., 12.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[12., 12., 12., 12., 12., 12.],
[246., 246., 246., 246.],
[396., 396., 396., 396.],
[396., 396., 396., 396.],
[282., 282., 282., 282.],
],
],
[
[
[12., 12., 12., 12., 12., 12.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[12., 12., 12., 12., 12., 12.],
],
[
[12., 12., 12., 12., 12., 12.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[18., 18., 18., 18., 18., 18.],
[12., 12., 12., 12., 12., 12.],
],
],
]),
]]),
weight: TestTensor::from_floats([
[
[[60., 60., 60.], [72., 72., 72.], [60., 60., 60.]],
[[60., 60., 60.], [72., 72., 72.], [60., 60., 60.]],
[[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],
[[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],
],
[
[[60., 60., 60.], [72., 72., 72.], [60., 60., 60.]],
[[60., 60., 60.], [72., 72., 72.], [60., 60., 60.]],
[[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],
[[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],
],
]),
bias: TestTensor::from_floats([96., 96.]),
bias: TestTensor::from_floats([24., 24.]),
};
test.assert_grads(grads);
}
@ -333,7 +250,7 @@ mod tests {
#[test]
fn test_conv2d_different_width() {
let test = Conv2dTestCase {
batch_size: 2,
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
@ -344,59 +261,36 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
height: 6,
groups: 1,
height: 4,
width: 5,
};
let grads = Grads {
x: TestTensor::from_floats([
x: TestTensor::from_floats([[
[
[
[8., 12., 12., 12., 8.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[8., 12., 12., 12., 8.],
[88., 138., 138., 138., 96.],
[150., 234., 234., 234., 162.],
[150., 234., 234., 234., 162.],
[112., 174., 174., 174., 120.],
],
[
[8., 12., 12., 12., 8.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[8., 12., 12., 12., 8.],
[160., 246., 246., 246., 168.],
[258., 396., 396., 396., 270.],
[258., 396., 396., 396., 270.],
[184., 282., 282., 282., 192.],
],
],
[
[
[8., 12., 12., 12., 8.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[8., 12., 12., 12., 8.],
],
[
[8., 12., 12., 12., 8.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[12., 18., 18., 18., 12.],
[8., 12., 12., 12., 8.],
],
],
]),
]]),
weight: TestTensor::from_floats([
[
[[40., 50., 40.], [48., 60., 48.], [40., 50., 40.]],
[[40., 50., 40.], [48., 60., 48.], [40., 50., 40.]],
[[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],
[[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],
],
[
[[40., 50., 40.], [48., 60., 48.], [40., 50., 40.]],
[[40., 50., 40.], [48., 60., 48.], [40., 50., 40.]],
[[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],
[[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],
],
]),
bias: TestTensor::from_floats([60., 60.]),
bias: TestTensor::from_floats([20., 20.]),
};
test.assert_grads(grads);
}
@ -404,7 +298,7 @@ mod tests {
#[test]
fn test_conv2d_stride_2() {
let test = Conv2dTestCase {
batch_size: 2,
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
@ -415,67 +309,40 @@ mod tests {
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
height: 8,
width: 8,
groups: 1,
height: 6,
width: 6,
};
let grads = Grads {
x: TestTensor::from_floats([
x: TestTensor::from_floats([[
[
[
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[26., 52., 26., 52., 26., 28.],
[52., 104., 52., 104., 52., 56.],
[26., 52., 26., 52., 26., 28.],
[52., 104., 52., 104., 52., 56.],
[26., 52., 26., 52., 26., 28.],
[32., 64., 32., 64., 32., 34.],
],
[
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[44., 88., 44., 88., 44., 46.],
[88., 176., 88., 176., 88., 92.],
[44., 88., 44., 88., 44., 46.],
[88., 176., 88., 176., 88., 92.],
[44., 88., 44., 88., 44., 46.],
[50., 100., 50., 100., 50., 52.],
],
],
[
[
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[2., 4., 2., 4., 2., 4., 2., 2.],
],
[
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[4., 8., 4., 8., 4., 8., 4., 4.],
[2., 4., 2., 4., 2., 4., 2., 2.],
[2., 4., 2., 4., 2., 4., 2., 2.],
],
],
]),
]]),
weight: TestTensor::from_floats([
[
[[18., 24., 24.], [24., 32., 32.], [24., 32., 32.]],
[[18., 24., 24.], [24., 32., 32.], [24., 32., 32.]],
[[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],
[[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],
],
[
[[18., 24., 24.], [24., 32., 32.], [24., 32., 32.]],
[[18., 24., 24.], [24., 32., 32.], [24., 32., 32.]],
[[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],
[[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],
],
]),
bias: TestTensor::from_floats([32., 32.]),
bias: TestTensor::from_floats([9., 9.]),
};
test.assert_grads(grads);
}
@ -483,7 +350,7 @@ mod tests {
#[test]
fn test_conv2d_different_stride() {
let test = Conv2dTestCase {
batch_size: 2,
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
@ -494,67 +361,252 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 8,
width: 8,
};
let grads = Grads {
x: TestTensor::from_floats([
x: TestTensor::from_floats([[
[
[
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[50., 78., 78., 78., 78., 78., 78., 54.],
[62., 96., 96., 96., 96., 96., 96., 66.],
[38., 60., 60., 60., 60., 60., 60., 42.],
[50., 78., 78., 78., 78., 78., 78., 54.],
[62., 96., 96., 96., 96., 96., 96., 66.],
[38., 60., 60., 60., 60., 60., 60., 42.],
[50., 78., 78., 78., 78., 78., 78., 54.],
[62., 96., 96., 96., 96., 96., 96., 66.],
],
[
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[86., 132., 132., 132., 132., 132., 132., 90.],
[98., 150., 150., 150., 150., 150., 150., 102.],
[74., 114., 114., 114., 114., 114., 114., 78.],
[86., 132., 132., 132., 132., 132., 132., 90.],
[98., 150., 150., 150., 150., 150., 150., 102.],
[74., 114., 114., 114., 114., 114., 114., 78.],
[86., 132., 132., 132., 132., 132., 132., 90.],
[98., 150., 150., 150., 150., 150., 150., 102.],
],
],
[
[
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
],
[
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
[4., 6., 6., 6., 6., 6., 6., 4.],
],
],
]),
]]),
weight: TestTensor::from_floats([
[
[[28., 32., 28.], [42., 48., 42.], [42., 48., 42.]],
[[28., 32., 28.], [42., 48., 42.], [42., 48., 42.]],
[[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],
[
[1330., 1528., 1344.],
[1911., 2196., 1932.],
[2079., 2388., 2100.],
],
],
[
[[28., 32., 28.], [42., 48., 42.], [42., 48., 42.]],
[[28., 32., 28.], [42., 48., 42.], [42., 48., 42.]],
[[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],
[
[1330., 1528., 1344.],
[1911., 2196., 1932.],
[2079., 2388., 2100.],
],
],
]),
bias: TestTensor::from_floats([48., 48.]),
bias: TestTensor::from_floats([24., 24.]),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_dilation_2() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 2,
dilation_2: 2,
groups: 1,
height: 6,
width: 6,
};
let grads = Grads {
x: TestTensor::from_floats([[
[
[18., 38., 38., 42., 42., 22.],
[42., 88., 88., 96., 96., 50.],
[42., 88., 88., 96., 96., 50.],
[54., 112., 112., 120., 120., 62.],
[54., 112., 112., 120., 120., 62.],
[30., 62., 62., 66., 66., 34.],
],
[
[36., 74., 74., 78., 78., 40.],
[78., 160., 160., 168., 168., 86.],
[78., 160., 160., 168., 168., 86.],
[90., 184., 184., 192., 192., 98.],
[90., 184., 184., 192., 192., 98.],
[48., 98., 98., 102., 102., 52.],
],
]]),
weight: TestTensor::from_floats([
[
[[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],
[[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],
],
[
[[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],
[[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],
],
]),
bias: TestTensor::from_floats([16., 16.]),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_different_dilation() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
stride_1: 1,
stride_2: 1,
dilation_1: 2,
dilation_2: 3,
groups: 1,
height: 6,
width: 6,
};
let grads = Grads {
x: TestTensor::from_floats([[
[
[18., 0., 20., 20., 0., 22.],
[42., 0., 46., 46., 0., 50.],
[42., 0., 46., 46., 0., 50.],
[54., 0., 58., 58., 0., 62.],
[54., 0., 58., 58., 0., 62.],
[30., 0., 32., 32., 0., 34.],
],
[
[36., 0., 38., 38., 0., 40.],
[78., 0., 82., 82., 0., 86.],
[78., 0., 82., 82., 0., 86.],
[90., 0., 94., 94., 0., 98.],
[90., 0., 94., 94., 0., 98.],
[48., 0., 50., 50., 0., 52.],
],
]]),
weight: TestTensor::from_floats([
[
[[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],
[[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],
],
[
[[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],
[[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],
],
]),
bias: TestTensor::from_floats([8., 8.]),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_groups() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 5,
width: 5,
};
let grads = Grads {
x: TestTensor::from_floats([[
[
[0., 1., 3., 3., 2.],
[3., 8., 15., 12., 7.],
[9., 21., 36., 27., 15.],
[9., 20., 33., 24., 13.],
[6., 13., 21., 15., 8.],
],
[
[9., 19., 30., 21., 11.],
[21., 44., 69., 48., 25.],
[36., 75., 117., 81., 42.],
[27., 56., 87., 60., 31.],
[15., 31., 48., 33., 17.],
],
]]),
weight: TestTensor::from_floats([
[[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]],
[[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]],
]),
bias: TestTensor::from_floats([9., 9.]),
};
test.assert_grads(grads);
}
#[test]
fn test_conv2d_groups_different_channels() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 3,
channels_out: 6,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 3,
height: 4,
width: 4,
};
let grads = Grads {
x: TestTensor::from_floats([[
[
[9., 20., 24., 13.],
[24., 52., 60., 32.],
[36., 76., 84., 44.],
[21., 44., 48., 25.],
],
[
[45., 92., 96., 49.],
[96., 196., 204., 104.],
[108., 220., 228., 116.],
[57., 116., 120., 61.],
],
[
[81., 164., 168., 85.],
[168., 340., 348., 176.],
[180., 364., 372., 188.],
[93., 188., 192., 97.],
],
]]),
weight: TestTensor::from_floats([
[[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]],
[[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]],
[[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]],
[[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]],
[[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]],
[[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]],
]),
bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.]),
};
test.assert_grads(grads);
}
@ -573,28 +625,38 @@ mod tests {
stride_2: 2,
dilation_1: 2,
dilation_2: 3,
groups: 1,
height: 4,
width: 5,
};
let grads = Grads {
x: TestTensor::from_floats([[
[
[3., 3., 0., 3., 3.],
[6., 6., 0., 6., 6.],
[6., 6., 0., 6., 6.],
[3., 3., 0., 3., 3.],
[36., 39., 0., 39., 42.],
[81., 87., 0., 87., 93.],
[81., 87., 0., 87., 93.],
[45., 48., 0., 48., 51.],
],
[
[3., 3., 0., 3., 3.],
[6., 6., 0., 6., 6.],
[6., 6., 0., 6., 6.],
[3., 3., 0., 3., 3.],
[54., 57., 0., 57., 60.],
[117., 123., 0., 123., 129.],
[117., 123., 0., 123., 129.],
[63., 66., 0., 66., 69.],
],
]]),
weight: TestTensor::from_floats([
[[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]],
[[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]],
[[[3., 6., 3.], [3., 6., 3.]], [[3., 6., 3.], [3., 6., 3.]]],
[
[[15., 42., 27.], [30., 72., 42.]],
[[75., 162., 87.], [90., 192., 102.]],
],
[
[[15., 42., 27.], [30., 72., 42.]],
[[75., 162., 87.], [90., 192., 102.]],
],
[
[[15., 42., 27.], [30., 72., 42.]],
[[75., 162., 87.], [90., 192., 102.]],
],
]),
bias: TestTensor::from_floats([8., 8., 8.]),
};
@ -613,6 +675,7 @@ mod tests {
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
groups: usize,
height: usize,
width: usize,
}
@ -625,24 +688,43 @@ mod tests {
impl Conv2dTestCase {
fn assert_grads(self, expected_grads: Grads) {
let weight = TestADTensor::ones([
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in,
self.channels_in / self.groups,
self.kernel_size_1,
self.kernel_size_2,
])
]);
let weight = TestADTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements())
.reshape(shape_weight)
.into_data()
.convert(),
)
.require_grad();
let bias = TestADTensor::ones([self.channels_out]).require_grad();
let x =
TestADTensor::ones([self.batch_size, self.channels_in, self.height, self.width])
let bias = TestADTensor::from_data(
TestTensorInt::arange(0..self.channels_out)
.into_data()
.convert(),
)
.require_grad();
let x = TestADTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
)
.require_grad();
let output = conv2d(
x.clone(),
weight.clone(),
Some(bias.clone()),
ConvOptions::new(
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.dilation_1, self.dilation_2],
self.groups,
),
);
let grads = output.backward();

View File

@ -8,6 +8,7 @@ use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::conv1d;
use burn_tensor::ops::conv::calculate_conv_padding;
use burn_tensor::ops::ConvOptions;
use libm::sqrt;
@ -26,6 +27,9 @@ pub struct Conv1dConfig {
/// Spacing between kernel elements.
#[config(default = "1")]
pub dilation: usize,
/// Controls the connections between input and output channels.
#[config(default = "1")]
pub groups: usize,
/// The padding configuration.
#[config(default = "Conv1dPaddingConfig::Valid")]
pub padding: Conv1dPaddingConfig,
@ -65,6 +69,7 @@ pub struct Conv1d<B: Backend> {
stride: usize,
kernel_size: usize,
dilation: usize,
groups: usize,
padding: Conv1dPaddingConfig,
}
@ -95,6 +100,7 @@ impl Conv1dConfig {
kernel_size: self.kernel_size,
padding: self.padding.clone(),
dilation: self.dilation,
groups: self.groups,
}
}
/// Initialize a new [conv1d](Conv1d) module with a [record](Conv1dRecord).
@ -106,6 +112,7 @@ impl Conv1dConfig {
kernel_size: self.kernel_size,
padding: self.padding.clone(),
dilation: self.dilation,
groups: self.groups,
}
}
}
@ -133,9 +140,7 @@ impl<B: Backend> Conv1d<B> {
input,
self.weight.val(),
self.bias.as_ref().map(|bias| bias.val()),
self.stride,
padding,
self.dilation,
ConvOptions::new([self.stride], [padding], [self.dilation], self.groups),
)
}
}

View File

@ -8,6 +8,7 @@ use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::module::conv2d;
use burn_tensor::ops::conv::calculate_conv_padding;
use burn_tensor::ops::ConvOptions;
use libm::sqrt;
@ -24,6 +25,9 @@ pub struct Conv2dConfig {
/// Spacing between kernel elements.
#[config(default = "[1, 1]")]
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
#[config(default = "1")]
pub groups: usize,
/// The padding configuration.
#[config(default = "Conv2dPaddingConfig::Valid")]
pub padding: Conv2dPaddingConfig,
@ -63,6 +67,7 @@ pub struct Conv2d<B: Backend> {
stride: [usize; 2],
kernel_size: [usize; 2],
dilation: [usize; 2],
groups: usize,
padding: Conv2dPaddingConfig,
}
@ -98,6 +103,7 @@ impl Conv2dConfig {
kernel_size: self.kernel_size,
dilation: self.dilation,
padding: self.padding.clone(),
groups: self.groups,
}
}
@ -110,6 +116,7 @@ impl Conv2dConfig {
dilation: self.dilation,
kernel_size: self.kernel_size,
padding: self.padding.clone(),
groups: self.groups,
}
}
}
@ -130,9 +137,7 @@ impl<B: Backend> Conv2d<B> {
input,
self.weight.val(),
self.bias.as_ref().map(|bias| bias.val()),
self.stride,
padding,
self.dilation,
ConvOptions::new(self.stride, padding, self.dilation, self.groups),
)
}
}
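With the new `groups` field on `Conv2dConfig` (default 1), a depthwise convolution is configured by setting `groups` equal to the number of input channels. A hypothetical sketch follows; only the `groups` field comes from this diff, while the `Conv2dConfig::new` arguments, the generated `with_groups` setter, and the `init` call are assumptions based on burn's usual `#[derive(Config)]` pattern:

```rust
use burn::nn::conv::{Conv2d, Conv2dConfig};
use burn::tensor::backend::Backend;

/// Assumed-constructor sketch of a depthwise 3x3 convolution (groups == channels_in).
fn depthwise_conv<B: Backend>() -> Conv2d<B> {
    Conv2dConfig::new([16, 16], [3, 3]) // assumed: [channels_in, channels_out], kernel size
        .with_groups(16)                // one filter group per input channel
        .init()
}
```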

View File

@ -15,6 +15,7 @@ pub(crate) trait NdArrayElement:
+ ndarray::ScalarOperand
+ ExpElement
+ num_traits::FromPrimitive
+ core::ops::AddAssign
+ core::cmp::PartialEq
+ core::cmp::PartialOrd<Self>
{

View File

@ -260,7 +260,7 @@ where
for (i, index) in indexes.iter().enumerate() {
let index = *index as usize;
tensor[[b, index]] = tensor[[b, index]] + value[[b, i]];
tensor[[b, index]] += value[[b, i]];
}
}
@ -351,7 +351,7 @@ where
let mut view = output_array.index_axis_mut(Axis(dim), index as usize);
let value = value.array.index_axis(Axis(0), index_value);
view.zip_mut_with(&value, |a, b| *a = *a + *b);
view.zip_mut_with(&value, |a, b| *a += *b);
}
NdArrayTensor::new(output_array.into_shared())

View File

@ -1,4 +1,7 @@
use burn_tensor::{ops::conv::calculate_conv_output_size, ElementConversion};
use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions, ConvTransposeOptions},
ElementConversion,
};
use ndarray::{Array4, Dim};
use crate::{
@ -10,13 +13,11 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,
bias: Option<NdArrayTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
dilatation: [usize; 2],
options: ConvOptions<2>,
) -> NdArrayTensor<E, 4> {
let [dilatation_height, dilatation_width] = dilatation;
let [padding_height, padding_width] = padding;
let [stride_height, stride_width] = stride;
let [dilatation_height, dilatation_width] = options.dilation;
let [padding_height, padding_width] = options.padding;
let [stride_height, stride_width] = options.stride;
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims;
@ -35,7 +36,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
in_width,
);
let x = apply_padding_4d(x, padding, 0i32.elem()).array;
let x = apply_padding_4d(x, options.padding, 0i32.elem()).array;
let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));
@ -45,10 +46,11 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
iter_par!(0, batch_size * out_channels).for_each(|k| unsafe {
let b = k / out_channels;
let oc = k % out_channels;
let g = k % options.groups;
let output = unsafe_shared_out.get();
for ic in 0..in_channels {
for ic in (in_channels * g)..(in_channels * (g + 1)) {
for kh in 0..kernel_height {
for kw in 0..kernel_width {
for oh in 0..out_height {
@ -56,8 +58,9 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
let ih = oh * stride_height + kh * dilatation_height;
let iw = ow * stride_width + kw * dilatation_width;
output[[b, oc, oh, ow]] = output[[b, oc, oh, ow]]
+ x[[b, ic, ih, iw]] * weight.array[[oc, ic, kh, kw]];
let weight_ic = ic - (g * in_channels);
output[[b, oc, oh, ow]] +=
x[[b, ic, ih, iw]] * weight.array[[oc, weight_ic, kh, kw]];
}
}
}
@ -67,7 +70,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement>(
if let Some(bias) = &bias {
for oh in 0..out_height {
for ow in 0..out_width {
output[[b, oc, oh, ow]] = output[[b, oc, oh, ow]] + bias.array[oc];
output[[b, oc, oh, ow]] += bias.array[oc];
}
}
}
@ -81,15 +84,12 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,
bias: Option<NdArrayTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
out_padding: [usize; 2],
dilation: [usize; 2],
options: ConvTransposeOptions<2>,
) -> NdArrayTensor<E, 4> {
let [dilation_height, dilation_width] = dilation;
let [padding_height, padding_width] = padding;
let [stride_height, stride_width] = stride;
let [out_padding_height, out_padding_width] = out_padding;
let [dilation_height, dilation_width] = options.dilation;
let [padding_height, padding_width] = options.padding;
let [stride_height, stride_width] = options.stride;
let [out_padding_height, out_padding_width] = options.padding_out;
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
let [in_channels, out_channels, kernel_height, kernel_width] = weight.shape().dims;
@ -104,18 +104,28 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
+ 1;
let x = x.array;
let mut output = Array4::zeros(Dim([batch_size, out_channels, out_height, out_width]));
let mut output = Array4::zeros(Dim([
batch_size,
out_channels * options.groups,
out_height,
out_width,
]));
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
run_par!(|| {
iter_par!(0, batch_size * out_channels).for_each(|k| unsafe {
let b = k / out_channels;
iter_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {
let b = k / (out_channels * options.groups);
let oc = k % out_channels;
let g = k % options.groups;
let output = unsafe_shared_out.get();
for ic in 0..in_channels {
let oc_out = oc + (out_channels * g);
let ic_start = g * (in_channels / options.groups);
let ic_end = ic_start + in_channels / options.groups;
for ic in ic_start..ic_end {
for ih in 0..in_height {
for iw in 0..in_width {
for kh in 0..kernel_height {
@ -134,8 +144,8 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
let oh = oh - padding_height;
let ow = ow - padding_width;
output[[b, oc, oh, ow]] = output[[b, oc, oh, ow]]
+ x[[b, ic, ih, iw]] * weight.array[[ic, oc, kh, kw]];
output[[b, oc_out, oh, ow]] +=
x[[b, ic, ih, iw]] * weight.array[[ic, oc, kh, kw]];
}
}
}
@ -145,7 +155,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
if let Some(bias) = &bias {
for oh in 0..out_height {
for ow in 0..out_width {
output[[b, oc, oh, ow]] = output[[b, oc, oh, ow]] + bias.array[oc];
output[[b, oc_out, oh, ow]] += bias.array[oc_out];
}
}
}

View File

@ -160,7 +160,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
let index_h = index as usize / width_x;
let index_w = index as usize % width_x;
output[[b, c, index_h, index_w]] = output[[b, c, index_h, index_w]] + grad;
output[[b, c, index_h, index_w]] += grad;
}
}
});

View File

@ -73,23 +73,18 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
x: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,
bias: Option<NdArrayTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
options: ConvOptions<2>,
) -> NdArrayTensor<E, 4> {
conv2d(x, weight, bias, stride, padding, dilation)
conv2d(x, weight, bias, options)
}
fn conv_transpose2d(
x: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,
bias: Option<NdArrayTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
padding_out: [usize; 2],
dilation: [usize; 2],
options: ConvTransposeOptions<2>,
) -> NdArrayTensor<E, 4> {
conv_transpose2d(x, weight, bias, stride, padding, padding_out, dilation)
conv_transpose2d(x, weight, bias, options)
}
fn max_pool2d(

View File

@ -1,5 +1,7 @@
use crate::{element::TchElement, TchBackend, TchTensor};
use burn_tensor::ops::{MaxPool2dBackward, MaxPool2dWithIndexes, ModuleOps};
use burn_tensor::ops::{
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndexes, ModuleOps,
};
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
fn embedding(weights: TchTensor<E, 2>, indexes: TchTensor<i64, 2>) -> TchTensor<E, 3> {
@ -30,18 +32,16 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
x: TchTensor<E, 3>,
weight: TchTensor<E, 3>,
bias: Option<TchTensor<E, 1>>,
stride: usize,
padding: usize,
dilation: usize,
options: ConvOptions<1>,
) -> TchTensor<E, 3> {
let tensor = tch::Tensor::conv1d(
&x.tensor,
&weight.tensor,
bias.map(|t| t.tensor),
&[stride as i64],
&[padding as i64],
&[dilation as i64],
1,
&options.stride.map(|i| i as i64),
&options.padding.map(|i| i as i64),
&options.dilation.map(|i| i as i64),
options.groups as i64,
);
TchTensor::new(tensor)
@ -51,18 +51,16 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
x: TchTensor<E, 4>,
weight: TchTensor<E, 4>,
bias: Option<TchTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
options: ConvOptions<2>,
) -> TchTensor<E, 4> {
let tensor = tch::Tensor::conv2d(
&x.tensor,
&weight.tensor,
bias.map(|t| t.tensor),
&[stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64],
&[dilation[0] as i64, dilation[1] as i64],
1,
&options.stride.map(|i| i as i64),
&options.padding.map(|i| i as i64),
&options.dilation.map(|i| i as i64),
options.groups as i64,
);
TchTensor::new(tensor)
@ -72,20 +70,17 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
x: TchTensor<E, 4>,
weight: TchTensor<E, 4>,
bias: Option<TchTensor<E, 1>>,
stride: [usize; 2],
padding: [usize; 2],
padding_out: [usize; 2],
dilation: [usize; 2],
options: ConvTransposeOptions<2>,
) -> TchTensor<E, 4> {
let tensor = tch::Tensor::conv_transpose2d(
&x.tensor,
&weight.tensor,
bias.map(|t| t.tensor),
&[stride[0] as i64, stride[1] as i64],
&[padding[0] as i64, padding[1] as i64],
&[padding_out[0] as i64, padding_out[1] as i64],
1,
&[dilation[0] as i64, dilation[1] as i64],
&options.stride.map(|i| i as i64),
&options.padding.map(|i| i as i64),
&options.padding_out.map(|i| i as i64),
options.groups as i64,
&options.dilation.map(|i| i as i64),
);
TchTensor::new(tensor)
@ -95,20 +90,17 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
x: TchTensor<E, 3>,
weight: TchTensor<E, 3>,
bias: Option<TchTensor<E, 1>>,
stride: usize,
padding: usize,
padding_out: usize,
dilation: usize,
options: ConvTransposeOptions<1>,
) -> TchTensor<E, 3> {
let tensor = tch::Tensor::conv_transpose1d(
&x.tensor,
&weight.tensor,
bias.map(|t| t.tensor),
&[stride as i64],
&[padding as i64],
&[padding_out as i64],
1,
&[dilation as i64],
&options.stride.map(|i| i as i64),
&options.padding.map(|i| i as i64),
&options.padding_out.map(|i| i as i64),
options.groups as i64,
&options.dilation.map(|i| i as i64),
);
TchTensor::new(tensor)

View File

@ -1,4 +1,8 @@
use crate::{backend::Backend, Int, Tensor};
use crate::{
backend::Backend,
ops::{ConvOptions, ConvTransposeOptions},
Int, Tensor,
};
/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
pub fn embedding<B>(weights: Tensor<B, 2>, indexes: Tensor<B, 2, Int>) -> Tensor<B, 3>
@ -13,9 +17,7 @@ pub fn conv1d<B>(
x: Tensor<B, 3>,
weight: Tensor<B, 3>,
bias: Option<Tensor<B, 1>>,
stride: usize,
padding: usize,
dilation: usize,
options: ConvOptions<1>,
) -> Tensor<B, 3>
where
B: Backend,
@ -24,9 +26,7 @@ where
x.primitive,
weight.primitive,
bias.map(|b| b.primitive),
stride,
padding,
dilation,
options,
))
}
@ -35,9 +35,7 @@ pub fn conv2d<B>(
x: Tensor<B, 4>,
weight: Tensor<B, 4>,
bias: Option<Tensor<B, 1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
options: ConvOptions<2>,
) -> Tensor<B, 4>
where
B: Backend,
@ -46,9 +44,7 @@ where
x.primitive,
weight.primitive,
bias.map(|b| b.primitive),
stride,
padding,
dilation,
options,
))
}
@ -57,10 +53,7 @@ pub fn conv_transpose1d<B>(
x: Tensor<B, 3>,
weight: Tensor<B, 3>,
bias: Option<Tensor<B, 1>>,
stride: usize,
padding: usize,
padding_out: usize,
dilation: usize,
options: ConvTransposeOptions<1>,
) -> Tensor<B, 3>
where
B: Backend,
@ -69,10 +62,7 @@ where
x.primitive,
weight.primitive,
bias.map(|b| b.primitive),
stride,
padding,
padding_out,
dilation,
options,
))
}
@ -81,10 +71,7 @@ pub fn conv_transpose2d<B>(
x: Tensor<B, 4>,
weight: Tensor<B, 4>,
bias: Option<Tensor<B, 1>>,
stride: [usize; 2],
padding: [usize; 2],
padding_out: [usize; 2],
dilation: [usize; 2],
options: ConvTransposeOptions<2>,
) -> Tensor<B, 4>
where
B: Backend,
@ -93,10 +80,7 @@ where
x.primitive,
weight.primitive,
bias.map(|b| b.primitive),
stride,
padding,
padding_out,
dilation,
options,
))
}

View File

@ -30,6 +30,25 @@ pub struct Conv1dBackward<B: Backend> {
pub bias_grad: Option<B::TensorPrimitive<1>>,
}
/// Convolution options.
#[derive(new, Debug, Clone)]
pub struct ConvOptions<const N: usize> {
pub stride: [usize; N],
pub padding: [usize; N],
pub dilation: [usize; N],
pub groups: usize,
}
/// Transposed convolution options.
#[derive(new, Debug, Clone)]
pub struct ConvTransposeOptions<const N: usize> {
pub stride: [usize; N],
pub padding: [usize; N],
pub padding_out: [usize; N],
pub dilation: [usize; N],
pub groups: usize,
}
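Both option structs derive `new`, so they are constructed positionally in field order. A small sketch (the values are arbitrary; only the field order comes from the definitions above):

```rust
use burn_tensor::ops::{ConvOptions, ConvTransposeOptions};

fn example_options() -> (ConvOptions<2>, ConvTransposeOptions<2>) {
    // stride, padding, dilation, groups
    let opts = ConvOptions::new([2, 2], [1, 1], [1, 1], 1);
    // stride, padding, padding_out, dilation, groups
    let t_opts = ConvTransposeOptions::new([2, 2], [1, 1], [0, 0], [1, 1], 1);
    (opts, t_opts)
}
```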
pub trait ModuleOps<B: Backend> {
fn embedding(
weights: B::TensorPrimitive<2>,
@ -51,9 +70,7 @@ pub trait ModuleOps<B: Backend> {
x: B::TensorPrimitive<4>,
weight: B::TensorPrimitive<4>,
bias: Option<B::TensorPrimitive<1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
options: ConvOptions<2>,
) -> B::TensorPrimitive<4>;
/// Two dimensional transposed convolution.
///
@ -66,10 +83,7 @@ pub trait ModuleOps<B: Backend> {
x: B::TensorPrimitive<4>,
weight: B::TensorPrimitive<4>,
bias: Option<B::TensorPrimitive<1>>,
stride: [usize; 2],
padding: [usize; 2],
padding_out: [usize; 2],
dilation: [usize; 2],
options: ConvTransposeOptions<2>,
) -> B::TensorPrimitive<4>;
/// Backward pass for the [conv2d](ModuleOps::conv2d) operation.
@ -77,12 +91,10 @@ pub trait ModuleOps<B: Backend> {
x: B::TensorPrimitive<4>,
weight: B::TensorPrimitive<4>,
bias: Option<B::TensorPrimitive<1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: B::TensorPrimitive<4>,
options: ConvOptions<2>,
) -> Conv2dBackward<B> {
conv::conv2d_backward(x, weight, bias, stride, padding, dilation, output_grad)
conv::conv2d_backward(x, weight, bias, output_grad, options)
}
/// One dimensional convolution.
///
@ -95,11 +107,9 @@ pub trait ModuleOps<B: Backend> {
x: B::TensorPrimitive<3>,
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
stride: usize,
padding: usize,
dilation: usize,
options: ConvOptions<1>,
) -> B::TensorPrimitive<3> {
conv::conv1d_from_conv2d::<B>(x, weight, bias, stride, padding, dilation)
conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
}
/// One dimensional transposed convolution.
///
@ -112,32 +122,19 @@ pub trait ModuleOps<B: Backend> {
x: B::TensorPrimitive<3>,
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
stride: usize,
padding: usize,
padding_out: usize,
dilation: usize,
options: ConvTransposeOptions<1>,
) -> B::TensorPrimitive<3> {
conv::conv_transpose1d_from_conv_transpose2d::<B>(
x,
weight,
bias,
stride,
padding,
padding_out,
dilation,
)
conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
}
/// Backward pass for the [conv1d](ModuleOps::conv1d) operation.
fn conv1d_backward(
x: B::TensorPrimitive<3>,
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
stride: usize,
padding: usize,
dilation: usize,
output_grad: B::TensorPrimitive<3>,
options: ConvOptions<1>,
) -> Conv1dBackward<B> {
conv::conv1d_backward(x, weight, bias, stride, padding, dilation, output_grad)
conv::conv1d_backward(x, weight, bias, output_grad, options)
}
/// Two dimensional max pooling.
///

View File

@ -1,4 +1,4 @@
use super::{Conv1dBackward, Conv2dBackward};
use super::{Conv1dBackward, Conv2dBackward, ConvOptions, ConvTransposeOptions};
use crate::{backend::Backend, Shape};
use libm::ceilf;
@ -31,43 +31,26 @@ pub fn calculate_conv_output_size(
(size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}
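For reference, `calculate_conv_output_size` above implements the standard convolution output-size formula, with integer division acting as the floor:

$$ \text{size}_{\text{out}} = \left\lfloor \frac{\text{size}_{\text{in}} + 2\,\text{padding} - \text{dilation}\,(\text{kernel} - 1) - 1}{\text{stride}} \right\rfloor + 1 $$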
fn calculate_padding_out(
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
size_in: usize,
size_out: usize,
) -> usize {
if stride <= 1 {
return 0;
}
let out = 1 + libm::ceil(
(size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64,
) as usize;
i64::max(0, out as i64 - size_out as i64) as usize
}
/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass using convolutions.
pub(crate) fn conv1d_backward<B: Backend>(
x: B::TensorPrimitive<3>,
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
stride: usize,
padding: usize,
dilation: usize,
output_grad: B::TensorPrimitive<3>,
options: ConvOptions<1>,
) -> Conv1dBackward<B> {
let [batch_size, channels_in, length_in] = B::shape(&x).dims;
let weight_shape = B::shape(&weight);
let weight_device = B::device(&weight);
let [batch_size, _, length_in] = B::shape(&x).dims;
let [_batch_size, channels_out, length_out] = B::shape(&output_grad).dims;
let [_, _, kernel_size] = B::shape(&weight).dims;
let [_, _, kernel_size] = weight_shape.dims;
let padding_out = calculate_padding_out(
kernel_size,
stride,
padding,
dilation,
options.stride[0],
options.padding[0],
options.dilation[0],
length_in,
length_out,
);
@ -76,36 +59,30 @@ pub(crate) fn conv1d_backward<B: Backend>(
output_grad.clone(),
weight,
None,
stride,
padding,
padding_out,
dilation,
ConvTransposeOptions::new(
options.stride,
options.padding,
[padding_out],
options.dilation,
options.groups,
),
);
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let weight_grad_swapped = B::conv1d(
x_swapped,
output_grad_swapped.clone(),
None,
dilation,
padding,
stride,
);
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
if B::shape(&weight_grad) != Shape::new([channels_out, channels_in, kernel_size]) {
weight_grad = B::index(
weight_grad,
[0..channels_out, 0..channels_in, 0..kernel_size],
);
}
let weight_grad = match options.groups == 1 {
true => conv1d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
false => conv1d_weight_grad_groups::<B>(
x,
B::zeros(weight_shape, &weight_device),
output_grad.clone(),
options,
),
};
Conv1dBackward::new(
x_grad,
weight_grad,
bias.map(|b| {
let grad = output_grad_swapped;
let grad = B::swap_dims(output_grad, 0, 1);
let grad = B::reshape(grad, Shape::new([channels_out, batch_size * length_out]));
let grad = B::sum_dim(grad, 1);
@ -119,28 +96,29 @@ pub(crate) fn conv2d_backward<B: Backend>(
x: B::TensorPrimitive<4>,
weight: B::TensorPrimitive<4>,
bias: Option<B::TensorPrimitive<1>>,
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
output_grad: B::TensorPrimitive<4>,
options: ConvOptions<2>,
) -> Conv2dBackward<B> {
let [batch_size, channels_in, height_in, width_in] = B::shape(&x).dims;
let [_batch_size, channels_out, height_out, width_out] = B::shape(&output_grad).dims;
let [_, _, kernel_size_1, kernel_size_2] = B::shape(&weight).dims;
let weight_shape = B::shape(&weight);
let weight_device = B::device(&weight);
let [batch_size, _channels_in, height_in, width_in] = B::shape(&x).dims;
let [_, _, height_out, width_out] = B::shape(&output_grad).dims;
let [channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims;
let padding_1_out = calculate_padding_out(
kernel_size_1,
stride[0],
padding[0],
dilation[0],
options.stride[0],
options.padding[0],
options.dilation[0],
height_in,
height_out,
);
let padding_2_out = calculate_padding_out(
kernel_size_2,
stride[1],
padding[1],
dilation[1],
options.stride[1],
options.padding[1],
options.dilation[1],
width_in,
width_out,
);
@ -149,43 +127,30 @@ pub(crate) fn conv2d_backward<B: Backend>(
output_grad.clone(),
weight,
None,
ConvTransposeOptions::new(
    options.stride,
    options.padding,
    [padding_1_out, padding_2_out],
    options.dilation,
    options.groups,
),
);
let weight_grad = match options.groups == 1 {
true => conv2d_weight_grad_no_groups::<B>(x, output_grad.clone(), weight_shape, options),
false => conv2d_weight_grad_groups::<B>(
x,
B::zeros(weight_shape, &weight_device),
output_grad.clone(),
options,
),
};
Conv2dBackward::new(
x_grad,
weight_grad,
bias.map(|b| {
let grad = B::swap_dims(output_grad, 0, 1);
let grad = B::reshape(
grad,
Shape::new([channels_out, batch_size * height_out * width_out]),
@ -202,20 +167,28 @@ pub(crate) fn conv1d_from_conv2d<B: Backend>(
x: B::TensorPrimitive<3>,
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
options: ConvOptions<1>,
) -> B::TensorPrimitive<3> {
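// Express the 1D convolution as a 2D convolution over a trailing unit dimension:
// x is reshaped to [batch, channels, length, 1] and the kernel to
// [channels_out, channels_in / groups, kernel_size, 1].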
let [channels_out, _channels_in, kernel_size] = B::shape(&weight).dims;
let [batch_size, channels_in, length_in] = B::shape(&x).dims;
let weight = B::reshape(
weight,
Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]),
);
let x = B::reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));
let tensor = B::conv2d(
x,
weight,
bias,
ConvOptions::new(
[options.stride[0], 1],
[options.padding[0], 0],
[options.dilation[0], 1],
options.groups,
),
);
let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims;
B::reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
}
@ -225,10 +198,7 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
x: B::TensorPrimitive<3>,
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
options: ConvTransposeOptions<1>,
) -> B::TensorPrimitive<3> {
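// Same unit-width trick as conv1d_from_conv2d, applied to the transposed convolution.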
let [channels_in, channels_out, kernel_size] = B::shape(&weight).dims;
let [batch_size, _channels_in, length_in] = B::shape(&x).dims;
@ -243,15 +213,174 @@ pub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(
x,
weight,
bias,
ConvTransposeOptions::new(
[options.stride[0], 1],
[options.padding[0], 0],
[options.padding_out[0], 0],
[options.dilation[0], 1],
options.groups,
),
);
let [batch_size, channels_out, height_out, _weight_out] = B::shape(&tensor).dims;
B::reshape(tensor, Shape::from([batch_size, channels_out, height_out]))
}
fn conv1d_weight_grad_groups<B: Backend>(
x: B::TensorPrimitive<3>,
mut weight_grad: B::TensorPrimitive<3>,
output_grad: B::TensorPrimitive<3>,
options: ConvOptions<1>,
) -> B::TensorPrimitive<3> {
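// Each group connects one block of input channels to one block of output channels, so
// the weight gradient is built group by group: slice the dimension-swapped input and
// output gradient, apply the same convolution trick as in the ungrouped case, then
// write the result into the matching slice of weight_grad.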
let [channels_out, increment_ci, kernel_size] = B::shape(&weight_grad).dims;
let increment_co = channels_out / options.groups;
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
for g in 0..options.groups {
let start_idx_ci = g * increment_ci;
let end_idx_ci = (g + 1) * increment_ci;
let start_idx_co = g * increment_co;
let end_idx_co = (g + 1) * increment_co;
let x = B::index(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::index(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let mut weight_grad_tmp = B::conv1d(
x,
grad,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
weight_grad = B::index_assign(
weight_grad,
[start_idx_co..end_idx_co, 0..increment_ci, 0..kernel_size],
weight_grad_tmp,
);
}
weight_grad
}
fn conv2d_weight_grad_groups<B: Backend>(
x: B::TensorPrimitive<4>,
mut weight_grad: B::TensorPrimitive<4>,
output_grad: B::TensorPrimitive<4>,
options: ConvOptions<2>,
) -> B::TensorPrimitive<4> {
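// Two-dimensional counterpart of conv1d_weight_grad_groups.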
let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = B::shape(&weight_grad).dims;
let increment_co = channels_out / options.groups;
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
for g in 0..options.groups {
let start_idx_ci = g * increment_ci;
let end_idx_ci = (g + 1) * increment_ci;
let start_idx_co = g * increment_co;
let end_idx_co = (g + 1) * increment_co;
let x = B::index(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
let grad = B::index(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
let mut weight_grad_tmp = B::conv2d(
x,
grad,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
weight_grad = B::index_assign(
weight_grad,
[
start_idx_co..end_idx_co,
0..increment_ci,
0..kernel_size_1,
0..kernel_size_2,
],
weight_grad_tmp,
);
}
weight_grad
}
fn conv1d_weight_grad_no_groups<B: Backend>(
x: B::TensorPrimitive<3>,
output_grad: B::TensorPrimitive<3>,
weight_shape: Shape<3>,
options: ConvOptions<1>,
) -> B::TensorPrimitive<3> {
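// Swapping the batch and channel dimensions turns the weight gradient into a convolution
// of the input with the output gradient, with the forward stride and dilation exchanging
// roles. The result can be spatially larger than the weight, so it is sliced back to
// weight_shape when necessary.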
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let weight_grad_swapped = B::conv1d(
x_swapped,
output_grad_swapped,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
if B::shape(&weight_grad) != weight_shape {
weight_grad = B::index(
weight_grad,
[
0..weight_shape.dims[0],
0..weight_shape.dims[1],
0..weight_shape.dims[2],
],
);
}
weight_grad
}
fn conv2d_weight_grad_no_groups<B: Backend>(
x: B::TensorPrimitive<4>,
output_grad: B::TensorPrimitive<4>,
weight_shape: Shape<4>,
options: ConvOptions<2>,
) -> B::TensorPrimitive<4> {
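// Two-dimensional counterpart of conv1d_weight_grad_no_groups.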
let x_swapped = B::swap_dims(x, 0, 1);
let output_grad_swapped = B::swap_dims(output_grad, 0, 1);
let weight_grad_swapped = B::conv2d(
x_swapped,
output_grad_swapped,
None,
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
);
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
if B::shape(&weight_grad) != weight_shape {
weight_grad = B::index(
weight_grad,
[
0..weight_shape.dims[0],
0..weight_shape.dims[1],
0..weight_shape.dims[2],
0..weight_shape.dims[3],
],
);
}
weight_grad
}
fn calculate_padding_out(
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
size_in: usize,
size_out: usize,
) -> usize {
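// With a stride greater than 1 the forward output size formula floors away information,
// so the transposed convolution used for the input gradient can fall short of size_in;
// this returns the extra output padding needed to recover it. For instance, with
// size_in = 4, kernel_size = 3, stride = 2, padding = 1 and dilation = 1, the forward
// output has size 2 while the expression below yields 3, so a padding_out of 1 is needed.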
if stride <= 1 {
return 0;
}
let out = 1 + libm::ceil(
(size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64,
) as usize;
i64::max(0, out as i64 - size_out as i64) as usize
}
#[cfg(test)]
mod tests {
use super::*;


@ -2,32 +2,26 @@
mod tests {
use super::*;
use burn_tensor::module::conv1d;
use burn_tensor::ops::ConvOptions;
use burn_tensor::{Data, Shape, Tensor};
#[test]
fn test_conv1d_simple() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
length: 4,
};
test.assert_output(TestTensor::from_floats([
[[43., 67., 82., 49.], [104., 176., 227., 158.]],
[[139., 187., 202., 113.], [392., 584., 635., 414.]],
]));
}
@ -41,12 +35,33 @@ mod tests {
padding: 1,
stride: 1,
dilation: 2,
groups: 1,
length: 4,
};
test.assert_output(TestTensor::from_floats([
[[62., 38.], [159., 111.]],
[[158., 102.], [447., 367.]],
]));
}
#[test]
fn test_conv1d_groups() {
let test = Conv1dTestCase {
batch_size: 2,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
stride: 1,
dilation: 1,
groups: 2,
length: 4,
};
test.assert_output(TestTensor::from_floats([
[[2., 5., 8., 3.], [42., 63., 75., 47.]],
[[26., 29., 32., 11.], [114., 159., 171., 103.]],
]));
}
@ -60,22 +75,13 @@ mod tests {
padding: 1,
stride: 2,
dilation: 1,
groups: 1,
length: 4,
};
test.assert_output(TestTensor::from_floats([
[[171., 294.], [415., 781.], [659., 1268.], [903., 1755.]],
[[495., 726.], [1387., 2185.], [2279., 3644.], [3171., 5103.]],
]));
}
@ -87,21 +93,40 @@ mod tests {
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
length: usize,
}
impl Conv1dTestCase {
fn assert_output(self, y: TestTensor<3>) {
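// Inputs, weights and bias are generated with arange so every channel holds distinct
// values, which keeps the expected outputs sensitive to channel ordering and grouping.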
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.groups,
self.kernel_size,
]);
let weight = TestTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements())
.reshape(shape_weight)
.into_data()
.convert(),
);
let bias = TestTensor::from_data(
TestTensorInt::arange(0..self.channels_out)
.into_data()
.convert(),
);
let x = TestTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
);
let output = conv1d(
x,
weight,
Some(bias),
ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),
);
y.to_data().assert_approx_eq(&output.into_data(), 3);


@ -2,14 +2,15 @@
mod tests {
use super::*;
use burn_tensor::module::conv2d;
use burn_tensor::ops::ConvOptions;
use burn_tensor::{Data, Shape, Tensor};
#[test]
fn test_conv2d_simple() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
@ -18,64 +19,54 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
test.assert_output(TestTensor::from_floats([[
    [
        [1196., 1796., 1916., 1264.],
        [1881., 2793., 2946., 1923.],
        [2313., 3405., 3558., 2307.],
        [1424., 2072., 2156., 1380.],
    ],
    [
        [2709., 4173., 4509., 3065.],
        [4582., 7006., 7483., 5056.],
        [5878., 8914., 9391., 6304.],
        [4089., 6177., 6477., 4333.],
    ],
]]));
}
#[test]
fn test_conv2d_groups() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 5,
width: 5,
};
test.assert_output(TestTensor::from_floats([[
    [[312., 348., 384.], [492., 528., 564.], [672., 708., 744.]],
    [
        [3724., 3841., 3958.],
        [4309., 4426., 4543.],
        [4894., 5011., 5128.],
    ],
]]));
}
#[test]
@ -92,22 +83,23 @@ mod tests {
stride_2: 3,
dilation_1: 1,
dilation_2: 2,
groups: 1,
height: 4,
width: 5,
};
test.assert_output(TestTensor::from_floats([
[
[[1845., 3789., 1926.], [3210., 6465., 3228.]],
[[4276., 9082., 4789.], [8071., 16834., 8737.]],
[[6707., 14375., 7652.], [12932., 27203., 14246.]],
[[9138., 19668., 10515.], [17793., 37572., 19755.]],
],
[
[[5445., 10629., 5166.], [8070., 15645., 7548.]],
[[14356., 28882., 14509.], [22651., 45454., 22777.]],
[[23267., 47135., 23852.], [37232., 75263., 38006.]],
[[32178., 65388., 33195.], [51813., 105072., 53235.]],
],
]));
}
@ -124,27 +116,47 @@ mod tests {
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
groups: usize,
height: usize,
width: usize,
}
impl Conv2dTestCase {
fn assert_output(self, y: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let shape_weight = Shape::new([
    self.channels_out,
    self.channels_in / self.groups,
    self.kernel_size_1,
    self.kernel_size_2,
]);
let weight = TestTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements())
.reshape(shape_weight)
.into_data()
.convert(),
);
let bias = TestTensor::from_data(
TestTensorInt::arange(0..self.channels_out)
.into_data()
.convert(),
);
let x = TestTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements())
.reshape(shape_x)
.into_data()
.convert(),
);
let output = conv2d(
x,
weight,
Some(bias),
ConvOptions::new(
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.dilation_1, self.dilation_2],
self.groups,
),
);
y.to_data().assert_approx_eq(&output.into_data(), 3);


@ -2,6 +2,7 @@
mod tests {
use super::*;
use burn_tensor::module::conv_transpose1d;
use burn_tensor::ops::ConvTransposeOptions;
use burn_tensor::{Data, Shape, Tensor};
#[test]
@ -15,6 +16,7 @@ mod tests {
padding_out: 0,
stride: 1,
dilation: 1,
groups: 1,
length: 4,
};
@ -35,6 +37,7 @@ mod tests {
padding_out: 1,
stride: 2,
dilation: 1,
groups: 1,
length: 4,
};
@ -55,6 +58,7 @@ mod tests {
padding_out: 0,
stride: 1,
dilation: 2,
groups: 1,
length: 4,
};
@ -64,6 +68,27 @@ mod tests {
]]));
}
#[test]
fn test_conv_transpose1d_groups() {
let test = ConvTranspose1dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size: 3,
padding: 1,
padding_out: 0,
stride: 1,
dilation: 1,
groups: 2,
length: 4,
};
test.assert_output(TestTensor::from_floats([[
[0., 1., 4., 7.],
[32., 59., 71., 59.],
]]));
}
struct ConvTranspose1dTestCase {
batch_size: usize,
channels_in: usize,
@ -73,13 +98,18 @@ mod tests {
padding_out: usize,
stride: usize,
dilation: usize,
groups: usize,
length: usize,
}
impl ConvTranspose1dTestCase {
fn assert_output(self, y: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);
let shape_weights = Shape::new([
self.channels_in,
self.channels_out / self.groups,
self.kernel_size,
]);
let weights = TestTensor::from_data(
TestTensorInt::arange(0..shape_weights.num_elements())
.reshape(shape_weights)
@ -101,10 +131,13 @@ mod tests {
x,
weights,
Some(bias),
ConvTransposeOptions::new(
[self.stride],
[self.padding],
[self.padding_out],
[self.dilation],
self.groups,
),
);
y.to_data().assert_approx_eq(&output.into_data(), 3);


@ -2,6 +2,7 @@
mod tests {
use super::*;
use burn_tensor::module::conv_transpose2d;
use burn_tensor::ops::ConvTransposeOptions;
use burn_tensor::{Data, Shape, Tensor};
#[test]
@ -20,6 +21,7 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 2,
width: 2,
};
@ -42,6 +44,7 @@ mod tests {
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
@ -84,6 +87,7 @@ mod tests {
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 2,
width: 2,
};
@ -112,6 +116,7 @@ mod tests {
stride_2: 1,
dilation_1: 2,
dilation_2: 2,
groups: 1,
height: 2,
width: 2,
};
@ -150,6 +155,7 @@ mod tests {
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 4,
width: 4,
};
@ -178,6 +184,94 @@ mod tests {
]]));
}
#[test]
fn test_conv_transpose2d_groups_2() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 2,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 1,
padding_2: 1,
padding_out_1: 0,
padding_out_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 2,
width: 2,
};
test.assert_output(TestTensor::from_floats([[
[[5., 11.], [23., 29.]],
[[236., 258.], [302., 324.]],
]]));
}
#[test]
fn test_conv_transpose2d_groups_different_channels() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 6,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
padding_out_1: 0,
padding_out_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 2,
width: 2,
};
test.assert_output(TestTensor::from_floats([[
[
[0.0000e+00, 0.0000e+00, 1.0000e+00, 2.0000e+00],
[0.0000e+00, 5.0000e+00, 1.1000e+01, 1.1000e+01],
[6.0000e+00, 2.3000e+01, 2.9000e+01, 2.3000e+01],
[1.2000e+01, 3.2000e+01, 3.7000e+01, 2.4000e+01],
],
[
[1.0000e+00, 1.0000e+01, 1.1000e+01, 1.2000e+01],
[1.9000e+01, 6.0000e+01, 6.6000e+01, 4.8000e+01],
[2.5000e+01, 7.8000e+01, 8.4000e+01, 6.0000e+01],
[3.1000e+01, 7.8000e+01, 8.3000e+01, 5.2000e+01],
],
[
[2.0000e+00, 2.0000e+01, 2.1000e+01, 2.2000e+01],
[3.8000e+01, 1.1500e+02, 1.2100e+02, 8.5000e+01],
[4.4000e+01, 1.3300e+02, 1.3900e+02, 9.7000e+01],
[5.0000e+01, 1.2400e+02, 1.2900e+02, 8.0000e+01],
],
[
[1.1100e+02, 2.5000e+02, 2.5900e+02, 1.4800e+02],
[2.8500e+02, 6.3400e+02, 6.5600e+02, 3.6600e+02],
[3.1500e+02, 7.0000e+02, 7.2200e+02, 4.0200e+02],
[2.0100e+02, 4.3800e+02, 4.5100e+02, 2.4800e+02],
],
[
[1.4800e+02, 3.3200e+02, 3.4100e+02, 1.9400e+02],
[3.7600e+02, 8.3300e+02, 8.5500e+02, 4.7500e+02],
[4.0600e+02, 8.9900e+02, 9.2100e+02, 5.1100e+02],
[2.5600e+02, 5.5600e+02, 5.6900e+02, 3.1200e+02],
],
[
[1.8500e+02, 4.1400e+02, 4.2300e+02, 2.4000e+02],
[4.6700e+02, 1.0320e+03, 1.0540e+03, 5.8400e+02],
[4.9700e+02, 1.0980e+03, 1.1200e+03, 6.2000e+02],
[3.1100e+02, 6.7400e+02, 6.8700e+02, 3.7600e+02],
],
]]));
}
struct ConvTranspose2dTestCase {
batch_size: usize,
channels_in: usize,
@ -192,6 +286,7 @@ mod tests {
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
groups: usize,
height: usize,
width: usize,
}
@ -201,7 +296,7 @@ mod tests {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let shape_weights = Shape::new([
self.channels_in,
self.channels_out / self.groups,
self.kernel_size_1,
self.kernel_size_2,
]);
@ -226,10 +321,13 @@ mod tests {
x,
weights,
Some(bias),
ConvTransposeOptions::new(
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.padding_out_1, self.padding_out_2],
[self.dilation_1, self.dilation_2],
self.groups,
),
);
y.to_data().assert_approx_eq(&output.into_data(), 3);