Fix concat backward with more than 1 dim (#402)

Nathaniel Simard 2023-06-15 09:18:15 -04:00 committed by GitHub
parent d57ca96695
commit 71d7ebbb21
2 changed files with 42 additions and 22 deletions


@@ -1252,6 +1252,9 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
 #[derive(new, Debug)]
 struct CatStep<B: Backend, const D: usize> {
     nodes: Vec<Option<NodeRef>>,
+    // The size of each input tensor along the concatenation dimension (`dim`).
+    // The backward pass uses these sizes to slice the output gradient per input.
+    dim_sizes: Vec<usize>,
     output: NodeRef,
     phantom: PhantomData<B>,
     dim: usize,
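The new `dim_sizes` field records each input's extent along the concatenation dimension. A minimal sketch in plain Rust of what it would hold (not burn's API; `collect_dim_sizes` is a hypothetical helper for illustration):

```rust
// Each input contributes `shape[dim]` indices to the concatenated output.
fn collect_dim_sizes(shapes: &[Vec<usize>], dim: usize) -> Vec<usize> {
    shapes.iter().map(|shape| shape[dim]).collect()
}

fn main() {
    // cat of a [2, 2] and a [3, 2] tensor along dim 0: output is [5, 2], dim_sizes is [2, 3].
    let dim_sizes = collect_dim_sizes(&[vec![2, 2], vec![3, 2]], 0);
    assert_eq!(dim_sizes, vec![2, 3]);
}
```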
@@ -1263,13 +1266,16 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
     let indexes: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect();
     let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();
+    let mut current_index = 0;
     self.nodes
         .into_iter()
-        .enumerate()
-        .filter_map(|(i, node)| node.map(|node| (i, node)))
-        .for_each(|(i, node)| {
+        .zip(self.dim_sizes.into_iter())
+        .filter_map(|(node, dim_size)| node.map(|node| (node, dim_size)))
+        .for_each(|(node, dim_size)| {
             let mut indexes = indexes.clone();
-            indexes[self.dim] = i..i + 1;
+            indexes[self.dim] = current_index..dim_size + current_index;
+            current_index += dim_size;
             grads.register::<B, D>(node, B::index(grad.clone(), indexes));
         });
 }
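The backward step now walks the recorded sizes and hands each input the slice of the output gradient it produced, instead of the old one-index slice `i..i + 1`. A minimal sketch of that range bookkeeping, in plain Rust with a hypothetical `grad_ranges` helper:

```rust
// Accumulate an offset over the recorded sizes to get each input's slice of
// the output gradient along the concatenation dimension.
fn grad_ranges(dim_sizes: &[usize]) -> Vec<std::ops::Range<usize>> {
    let mut current_index = 0;
    dim_sizes
        .iter()
        .map(|&dim_size| {
            let range = current_index..current_index + dim_size;
            current_index += dim_size;
            range
        })
        .collect()
}

fn main() {
    // Inputs of sizes 2 and 3 along the cat dimension: the first receives
    // grad[0..2], the second grad[2..5]. The old `i..i + 1` logic handed each
    // input a single-index slice, which is wrong for inputs larger than 1
    // along `dim`.
    assert_eq!(grad_ranges(&[2, 3]), vec![0..2, 2..5]);
}
```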
@@ -1282,8 +1288,10 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
     let mut nodes = Vec::with_capacity(tensors.len());
     let mut graphs = Vec::with_capacity(tensors.len());
     let mut primitives = Vec::with_capacity(tensors.len());
+    let mut dim_sizes = Vec::with_capacity(tensors.len());
     tensors.into_iter().for_each(|tensor| {
+        dim_sizes.push(B::shape(&tensor.primitive).dims[dim]);
         nodes.push(tensor.node);
         primitives.push(tensor.primitive);
         graphs.push(tensor.graph);
@@ -1302,9 +1310,10 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
         .map(|node| node.clone_if_require_grad())
         .collect::<Vec<_>>();
-    let ops = CatStep::<B, D>::new(nodes, output.node.clone(), dim);
+    let ops = CatStep::<B, D>::new(nodes, dim_sizes, output.node.clone(), dim);
     output.register_step(ops)
 }
 fn max_dim<const D: usize>(tensor: ADTensor<B, D>, dim: usize) -> ADTensor<B, D> {
     match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
         OpsKind::Tracked(prep) => {


@@ -5,11 +5,8 @@ mod tests {
     #[test]
     fn should_diff_cat() {
-        let data_1 = Data::<_, 2>::from([[2.0, -1.0], [5.0, 2.0]]);
-        let data_2 = Data::<_, 2>::from([[5.0, 4.0], [-1.0, 4.0]]);
-        let tensor_1 = TestADTensor::from_data(data_1).require_grad();
-        let tensor_2 = TestADTensor::from_data(data_2).require_grad();
+        let tensor_1 = TestADTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad();
+        let tensor_2 = TestADTensor::from_data([[5.0, 4.0], [-1.0, 4.0]]).require_grad();
         let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
         let grads = tensor_3.backward();
@@ -21,26 +18,21 @@ mod tests {
         let mut tensor_2_list = Vec::new();
         for i in 0..2 {
-            tensor_1_list.push(tensor_1.clone().index([i..i + 1]).detach().require_grad());
-            tensor_2_list.push(tensor_2.clone().index([i..i + 1]).detach().require_grad());
+            tensor_1_list.push(tensor_1.clone().index([i..i + 1]));
+            tensor_2_list.push(tensor_2.clone().index([i..i + 1]));
         }
         let tensor_1_cat = TestADTensor::cat(tensor_1_list.clone(), 0);
         let tensor_2_cat = TestADTensor::cat(tensor_2_list.clone(), 0);
         let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone());
-        let grads_cat = tensor_3_cat.backward();
+        let grads = tensor_3_cat.backward();
-        let grad = |tensor: Option<&TestADTensor<2, Float>>| {
-            tensor
-                .map(|tensor| tensor.grad(&grads_cat).unwrap())
-                .unwrap()
-        };
-        let grad_1_index_1 = grad(tensor_1_list.get(0));
-        let grad_1_index_2 = grad(tensor_1_list.get(1));
+        let grad_1_index_1 = tensor_1.grad(&grads).unwrap().index([0..1]);
+        let grad_1_index_2 = tensor_1.grad(&grads).unwrap().index([1..2]);
-        let grad_2_index_1 = grad(tensor_2_list.get(0));
-        let grad_2_index_2 = grad(tensor_2_list.get(1));
+        let grad_2_index_1 = tensor_2.grad(&grads).unwrap().index([0..1]);
+        let grad_2_index_2 = tensor_2.grad(&grads).unwrap().index([1..2]);
         grad_1
             .clone()
@@ -62,4 +54,23 @@ mod tests {
             .to_data()
             .assert_approx_eq(&grad_2_index_2.to_data(), 3);
     }
+    #[test]
+    fn should_diff_cat_more_than_1_dim() {
+        let tensor_1 = TestADTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad();
+        let tensor_2 =
+            TestADTensor::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]]).require_grad();
+        // Concatenate a [2, 2] tensor with a [3, 2] tensor along dim 0.
+        // The resulting tensor should have shape [5, 2].
+        let tensor_3 = TestADTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0);
+        assert_eq!(tensor_3.dims(), [5, 2]);
+        let grads = tensor_3.backward();
+        let grad_1 = tensor_1.grad(&grads).unwrap();
+        let grad_2 = tensor_2.grad(&grads).unwrap();
+        assert_eq!(tensor_1.dims(), grad_1.dims());
+        assert_eq!(tensor_2.dims(), grad_2.dims());
+    }
 }
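The new test deliberately uses inputs with different sizes along dim 0 (2 and 3), where the old one-index slicing would have produced wrongly shaped gradients, and asserts only that each gradient matches its input's shape. A minimal sketch of that property for dim 0, in plain Rust with nested `Vec`s standing in for tensors (`cat_backward_dim0` is a hypothetical helper, not burn's API):

```rust
// Split the upstream gradient of the concatenated output back into per-input
// slices along dim 0, using the recorded sizes.
fn cat_backward_dim0(upstream: &[Vec<f64>], dim_sizes: &[usize]) -> Vec<Vec<Vec<f64>>> {
    let mut start = 0;
    dim_sizes
        .iter()
        .map(|&size| {
            let slice = upstream[start..start + size].to_vec();
            start += size;
            slice
        })
        .collect()
}

fn main() {
    // Upstream gradient for the [5, 2] output of cat([2, 2], [3, 2], dim = 0).
    let upstream = vec![vec![1.0, 1.0]; 5];
    let grads = cat_backward_dim0(&upstream, &[2, 3]);
    assert_eq!(grads[0].len(), 2); // matches tensor_1's size along dim 0
    assert_eq!(grads[1].len(), 3); // matches tensor_2's size along dim 0
}
```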