mirror of https://github.com/tracel-ai/burn.git
Fix concat backward with more than 1 dim (#402)
commit 71d7ebbb21
parent d57ca96695
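In short: the backward step of `cat` used to slice the output gradient with `i..i + 1`, assuming every concatenated tensor spanned exactly one entry along the concatenation dimension. The fix records each input's size along `dim` (`dim_sizes`) when `cat` is called and slices the gradient at a running offset instead, so every input receives a gradient of its own shape.

A minimal sketch of the corrected slicing logic, written as a standalone function with hypothetical names (not part of the commit or of burn's API), just to show how the slice ranges follow from `dim_sizes`:

    use std::ops::Range;

    /// For tensors concatenated along one dimension, compute which slice of the
    /// concatenated output (and of its gradient) belongs to each input.
    fn grad_ranges(dim_sizes: &[usize]) -> Vec<Range<usize>> {
        let mut current_index = 0;
        dim_sizes
            .iter()
            .map(|size| {
                let range = current_index..current_index + *size;
                current_index += *size;
                range
            })
            .collect()
    }

    fn main() {
        // Inputs of sizes 2 and 3 along `dim` (as in the new test below) map to
        // slices 0..2 and 2..5 of the output gradient.
        assert_eq!(grad_ranges(&[2, 3]), vec![0..2, 2..5]);
    }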
@@ -1252,6 +1252,9 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
 #[derive(new, Debug)]
 struct CatStep<B: Backend, const D: usize> {
     nodes: Vec<Option<NodeRef>>,
+    // The size of each tensor along the `dim` dimension.
+    // This indicates how many entries each tensor contributed to the concatenation.
+    dim_sizes: Vec<usize>,
     output: NodeRef,
     phantom: PhantomData<B>,
     dim: usize,
@@ -1263,13 +1266,16 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
         let indexes: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect();
         let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();

+        let mut current_index = 0;
+
         self.nodes
             .into_iter()
-            .enumerate()
-            .filter_map(|(i, node)| node.map(|node| (i, node)))
-            .for_each(|(i, node)| {
+            .zip(self.dim_sizes.into_iter())
+            .filter_map(|(node, dim_size)| node.map(|node| (node, dim_size)))
+            .for_each(|(node, dim_size)| {
                 let mut indexes = indexes.clone();
-                indexes[self.dim] = i..i + 1;
+                indexes[self.dim] = current_index..dim_size + current_index;
+                current_index += dim_size;
                 grads.register::<B, D>(node, B::index(grad.clone(), indexes));
             });
     }
@@ -1282,8 +1288,10 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
         let mut nodes = Vec::with_capacity(tensors.len());
         let mut graphs = Vec::with_capacity(tensors.len());
         let mut primitives = Vec::with_capacity(tensors.len());
+        let mut dim_sizes = Vec::with_capacity(tensors.len());

         tensors.into_iter().for_each(|tensor| {
+            dim_sizes.push(B::shape(&tensor.primitive).dims[dim]);
             nodes.push(tensor.node);
             primitives.push(tensor.primitive);
             graphs.push(tensor.graph);
@@ -1302,9 +1310,10 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
             .map(|node| node.clone_if_require_grad())
             .collect::<Vec<_>>();

-        let ops = CatStep::<B, D>::new(nodes, output.node.clone(), dim);
+        let ops = CatStep::<B, D>::new(nodes, dim_sizes, output.node.clone(), dim);
         output.register_step(ops)
     }

     fn max_dim<const D: usize>(tensor: ADTensor<B, D>, dim: usize) -> ADTensor<B, D> {
         match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
             OpsKind::Tracked(prep) => {
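Note that the sizes along `dim` are recorded in the forward `cat`, before each `tensor.primitive` is moved into the `primitives` vector, so the backward step can recover the slice boundaries without needing the original inputs.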
@@ -5,11 +5,8 @@ mod tests {

     #[test]
     fn should_diff_cat() {
-        let data_1 = Data::<_, 2>::from([[2.0, -1.0], [5.0, 2.0]]);
-        let data_2 = Data::<_, 2>::from([[5.0, 4.0], [-1.0, 4.0]]);
-
-        let tensor_1 = TestADTensor::from_data(data_1).require_grad();
-        let tensor_2 = TestADTensor::from_data(data_2).require_grad();
+        let tensor_1 = TestADTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad();
+        let tensor_2 = TestADTensor::from_data([[5.0, 4.0], [-1.0, 4.0]]).require_grad();

         let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
         let grads = tensor_3.backward();
@@ -21,26 +18,21 @@ mod tests {
         let mut tensor_2_list = Vec::new();

         for i in 0..2 {
-            tensor_1_list.push(tensor_1.clone().index([i..i + 1]).detach().require_grad());
-            tensor_2_list.push(tensor_2.clone().index([i..i + 1]).detach().require_grad());
+            tensor_1_list.push(tensor_1.clone().index([i..i + 1]));
+            tensor_2_list.push(tensor_2.clone().index([i..i + 1]));
         }

         let tensor_1_cat = TestADTensor::cat(tensor_1_list.clone(), 0);
         let tensor_2_cat = TestADTensor::cat(tensor_2_list.clone(), 0);

         let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone());
-        let grads_cat = tensor_3_cat.backward();
+        let grads = tensor_3_cat.backward();

-        let grad = |tensor: Option<&TestADTensor<2, Float>>| {
-            tensor
-                .map(|tensor| tensor.grad(&grads_cat).unwrap())
-                .unwrap()
-        };
-        let grad_1_index_1 = grad(tensor_1_list.get(0));
-        let grad_1_index_2 = grad(tensor_1_list.get(1));
+        let grad_1_index_1 = tensor_1.grad(&grads).unwrap().index([0..1]);
+        let grad_1_index_2 = tensor_1.grad(&grads).unwrap().index([1..2]);

-        let grad_2_index_1 = grad(tensor_2_list.get(0));
-        let grad_2_index_2 = grad(tensor_2_list.get(1));
+        let grad_2_index_1 = tensor_2.grad(&grads).unwrap().index([0..1]);
+        let grad_2_index_2 = tensor_2.grad(&grads).unwrap().index([1..2]);

         grad_1
             .clone()
@@ -62,4 +54,23 @@ mod tests {
             .to_data()
             .assert_approx_eq(&grad_2_index_2.to_data(), 3);
     }
+
+    #[test]
+    fn should_diff_cat_more_than_1_dim() {
+        let tensor_1 = TestADTensor::from_data([[2.0, -1.0], [5.0, 2.0]]).require_grad();
+        let tensor_2 =
+            TestADTensor::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]]).require_grad();
+
+        // Concat a tensor [2, 2] with another tensor [3, 2] along dim 0.
+        // The resulting tensor should be [5, 2].
+        let tensor_3 = TestADTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0);
+        assert_eq!(tensor_3.dims(), [5, 2]);
+        let grads = tensor_3.backward();
+
+        let grad_1 = tensor_1.grad(&grads).unwrap();
+        let grad_2 = tensor_2.grad(&grads).unwrap();
+
+        assert_eq!(tensor_1.dims(), grad_1.dims());
+        assert_eq!(tensor_2.dims(), grad_2.dims());
+    }
 }
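The new test concatenates inputs of sizes 2 and 3 along dim 0 and checks that each gradient has the same shape as its input, which the old `i..i + 1` slicing could not produce. The assertions stop at shapes; a hypothetical extra check (not in the commit) could also verify the values: assuming `backward()` seeds the output gradient with ones, the gradient of `cat` with respect to each input is just the matching slice of that seed, i.e. an all-ones tensor of the input's shape.

    // Hypothetical additional assertions, under the all-ones seed assumption above.
    grad_1
        .to_data()
        .assert_approx_eq(&Data::<_, 2>::from([[1.0, 1.0], [1.0, 1.0]]), 3);
    grad_2
        .to_data()
        .assert_approx_eq(&Data::<_, 2>::from([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]), 3);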