Fix select assign backward (#1739)

This commit is contained in:
Nathaniel Simard 2024-05-07 11:37:43 -04:00 committed by GitHub
parent bd06b38fac
commit a6e3b4e81e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 28 deletions

View File

@ -1058,7 +1058,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
impl<B: Backend, const D: usize> Backward<B, D, 2> for IndexSelectDimAssign<D> {
type State = (usize, IntTensor<B, 1>, Shape<D>, Shape<D>, B::Device);
type State = (usize, IntTensor<B, 1>);
fn backward(
self,
@ -1066,21 +1066,14 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let (dim, indices, shape_lhs, shape_rhs, device) = ops.state;
let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices));
let (dim, indices) = ops.state;
binary::<B, D, D, D, _, _>(
ops.parents,
ops.node,
grads,
|grad| {
let zeros = B::float_zeros(shape_lhs, &device);
B::float_select_assign(grad, dim, indices_4lhs.unwrap(), zeros)
},
|grad| {
let zeros = B::float_zeros(shape_rhs, &device);
B::float_select_assign(zeros, dim, indices_4rhs.unwrap(), grad)
},
|grad| grad,
|grad| B::float_select(grad, dim, indices),
);
}
}
@ -1098,13 +1091,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
dim,
indices.clone(),
B::float_shape(&tensor.primitive),
B::float_shape(&value.primitive),
B::float_device(&value.primitive),
),
(dim, indices.clone()),
B::float_select_assign(tensor.primitive, dim, indices, value.primitive),
),
OpsKind::UnTracked(prep) => prep.finish(B::float_select_assign(

View File

@ -54,4 +54,23 @@ mod tests {
Data::from([[64., 64., 64.], [19., 19., 19.]])
);
}
// Regression test (added by this commit): backward of `select_assign` when the
// assigned value and the destination tensor have DIFFERENT shapes, so the
// gradient w.r.t. each side must keep that side's own shape.
#[test]
fn test_select_assign_grad_different_shapes() {
let device = Default::default();
// Single index: assign into position 1 along dim 0.
let indices: Tensor<TestAutodiffBackend, 1, Int> = Tensor::from_ints([1], &device);
// x is [1, 1] (the value being assigned), y is [2, 1] (the destination).
let x: Tensor<TestAutodiffBackend, 2> = Tensor::ones([1, 1], &device).require_grad();
let y = Tensor::ones([2, 1], &device).require_grad();
let w = y.clone().select_assign(0, indices, x.clone());
// [2, 1] x [1, 2] -> [2, 2]; gives both x and y a path into the loss.
let w = w.matmul(y.clone().transpose());
let grads = w.backward();
let x_grad = x.grad(&grads).unwrap();
let y_grad = y.grad(&grads).unwrap();
// Each gradient retains its tensor's original shape: x_grad is [1, 1],
// y_grad is [2, 1] — the values below are the expected results pinned
// by this commit's fix.
assert_eq!(x_grad.into_data(), Data::from([[2.0]]));
assert_eq!(y_grad.into_data(), Data::from([[5.0], [5.0]]));
}
}

View File

@ -125,18 +125,12 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
// tch backend implementation of `select_assign`.
//
// NOTE(review): this span is a stripped commit diff — the `-`/`+` markers are
// gone, so the REMOVED and the ADDED implementations both appear below. As
// displayed it is not valid Rust; comments mark which lines belong to which
// version of the function.
pub fn select_assign<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
// REMOVED by this commit: old parameter name.
indices_tensor: TchTensor<i64, 1>,
// ADDED by this commit: parameter renamed to `indices`.
indices: TchTensor<i64, 1>,
value: TchTensor<E, D>,
) -> TchTensor<E, D> {
// REMOVED by this commit: `index_put_`-based body. It built a per-dimension
// index list (None for every dim except `dim`) and used index_put with
// accumulate = true.
let mut indices = Vec::with_capacity(D);
for _ in 0..D {
indices.push(None);
}
indices[dim] = Some(indices_tensor.tensor);
tensor.unary_ops(
|mut tensor| tensor.index_put_(&indices, &value.tensor, true),
|tensor| tensor.index_put(&indices, &value.tensor, true),
// ADDED by this commit: `index_add`-based body — accumulates `value` into
// the slices of `tensor` selected by `indices` along `dim`. The in-place
// variant is used when the tensor is uniquely owned, the copying variant
// otherwise (that is the unary_ops pair below).
tensor.clone().unary_ops(
|mut tensor| tensor.index_add_(dim as i64, &indices.tensor, &value.tensor),
|tensor| tensor.index_add(dim as i64, &indices.tensor, &value.tensor),
)
}