mirror of https://github.com/tracel-ai/burn.git
Fix select assign backward (#1739)
This commit is contained in:
parent
bd06b38fac
commit
a6e3b4e81e
|
@ -1058,7 +1058,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Backward<B, D, 2> for IndexSelectDimAssign<D> {
|
||||
type State = (usize, IntTensor<B, 1>, Shape<D>, Shape<D>, B::Device);
|
||||
type State = (usize, IntTensor<B, 1>);
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
|
@ -1066,21 +1066,14 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
grads: &mut Gradients,
|
||||
_checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let (dim, indices, shape_lhs, shape_rhs, device) = ops.state;
|
||||
let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices));
|
||||
let (dim, indices) = ops.state;
|
||||
|
||||
binary::<B, D, D, D, _, _>(
|
||||
ops.parents,
|
||||
ops.node,
|
||||
grads,
|
||||
|grad| {
|
||||
let zeros = B::float_zeros(shape_lhs, &device);
|
||||
B::float_select_assign(grad, dim, indices_4lhs.unwrap(), zeros)
|
||||
},
|
||||
|grad| {
|
||||
let zeros = B::float_zeros(shape_rhs, &device);
|
||||
B::float_select_assign(zeros, dim, indices_4rhs.unwrap(), grad)
|
||||
},
|
||||
|grad| grad,
|
||||
|grad| B::float_select(grad, dim, indices),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -1098,13 +1091,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(prep) => prep.finish(
|
||||
(
|
||||
dim,
|
||||
indices.clone(),
|
||||
B::float_shape(&tensor.primitive),
|
||||
B::float_shape(&value.primitive),
|
||||
B::float_device(&value.primitive),
|
||||
),
|
||||
(dim, indices.clone()),
|
||||
B::float_select_assign(tensor.primitive, dim, indices, value.primitive),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::float_select_assign(
|
||||
|
|
|
@ -54,4 +54,23 @@ mod tests {
|
|||
Data::from([[64., 64., 64.], [19., 19., 19.]])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_assign_grad_different_shapes() {
|
||||
let device = Default::default();
|
||||
|
||||
let indices: Tensor<TestAutodiffBackend, 1, Int> = Tensor::from_ints([1], &device);
|
||||
let x: Tensor<TestAutodiffBackend, 2> = Tensor::ones([1, 1], &device).require_grad();
|
||||
let y = Tensor::ones([2, 1], &device).require_grad();
|
||||
|
||||
let w = y.clone().select_assign(0, indices, x.clone());
|
||||
let w = w.matmul(y.clone().transpose());
|
||||
|
||||
let grads = w.backward();
|
||||
let x_grad = x.grad(&grads).unwrap();
|
||||
let y_grad = y.grad(&grads).unwrap();
|
||||
|
||||
assert_eq!(x_grad.into_data(), Data::from([[2.0]]));
|
||||
assert_eq!(y_grad.into_data(), Data::from([[5.0], [5.0]]));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -125,18 +125,12 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
pub fn select_assign<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
indices_tensor: TchTensor<i64, 1>,
|
||||
indices: TchTensor<i64, 1>,
|
||||
value: TchTensor<E, D>,
|
||||
) -> TchTensor<E, D> {
|
||||
let mut indices = Vec::with_capacity(D);
|
||||
for _ in 0..D {
|
||||
indices.push(None);
|
||||
}
|
||||
indices[dim] = Some(indices_tensor.tensor);
|
||||
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.index_put_(&indices, &value.tensor, true),
|
||||
|tensor| tensor.index_put(&indices, &value.tensor, true),
|
||||
tensor.clone().unary_ops(
|
||||
|mut tensor| tensor.index_add_(dim as i64, &indices.tensor, &value.tensor),
|
||||
|tensor| tensor.index_add(dim as i64, &indices.tensor, &value.tensor),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue