refactor: transpose (#71)

Nathaniel Simard 2022-11-05 20:18:31 -04:00 committed by GitHub
parent ad23898d23
commit e6541298b9
14 changed files with 115 additions and 159 deletions

View File

@@ -15,7 +15,6 @@ mod pow;
mod precision;
mod reshape;
mod tensor;
mod transpose;
mod macros;
pub(crate) use base::*;

View File

@@ -5,7 +5,7 @@ use crate::{
Backend,
},
graph::ops::{BinaryOps, BinaryOpsNodeState, UnaryOps, UnaryOpsNodeState},
ops::{Ones, TensorOps, TensorOpsTranspose},
ops::{Ones, TensorOps},
Data, Shape,
};
@@ -399,7 +399,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
>,
) -> B::TensorPrimitive<D> {
let out_grad = state.output.grad();
let rhs = state.right.value().transpose();
let rhs = B::transpose(&state.right.value());
B::matmul(&out_grad, &rhs)
}
@@ -412,7 +412,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
>,
) -> B::TensorPrimitive<D> {
let out_grad = state.output.grad();
let lhs = state.left.value().transpose();
let lhs = B::transpose(&state.left.value());
B::matmul(&lhs, &out_grad)
}
}
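For context, the two gradient functions above implement the standard matmul backward rule (restated here for readability, not something introduced by this commit): for C = X · Y with upstream gradient dL/dC,

    dL/dX = (dL/dC) · Yᵀ        dL/dY = Xᵀ · (dL/dC)

which is exactly what the code computes, now through the backend-level B::transpose and B::matmul instead of the removed TensorOpsTranspose::transpose method.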
@@ -447,4 +447,33 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
unary_ops_wrapper(tensor.node.clone(), output, ops)
}
fn swap_dims<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
dim1: usize,
dim2: usize,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
#[derive(new, Debug)]
struct SwapDimsBackward<B: Backend, const D: usize> {
_b: B,
dim1: usize,
dim2: usize,
}
impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for SwapDimsBackward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
B::swap_dims(&state.output.grad(), self.dim2, self.dim1)
}
}
let output = B::swap_dims(tensor.tensor_ref(), dim1, dim2);
let ops = SwapDimsBackward::<B, D>::new(B::default(), dim1, dim2);
unary_ops_wrapper(tensor.node.clone(), output, ops)
}
}
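The backward step above relies on swapping two dimensions being its own inverse: if y = swap_dims(x, dim1, dim2), the gradient with respect to x is simply the upstream gradient with the same two dimensions swapped back (passing the arguments as dim2, dim1 is the same permutation). A minimal standalone sketch of that property in plain Rust, with no burn types and a hypothetical transpose2 helper for the 2x2 case:

// Illustration only: applying the same swap twice is the identity, so mapping
// the upstream gradient through the swap recovers the original layout.
fn transpose2(m: [[f32; 2]; 2]) -> [[f32; 2]; 2] {
    [[m[0][0], m[1][0]], [m[0][1], m[1][1]]]
}

fn main() {
    let grad_y = [[1.0, 2.0], [3.0, 4.0]];
    let grad_x = transpose2(grad_y); // same swap applied to the incoming gradient
    assert_eq!(transpose2(grad_x), grad_y); // the swap undoes itself
}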

View File

@@ -1,93 +0,0 @@
use crate::graph::ops::{UnaryOps, UnaryOpsNodeState};
use crate::tensor::backend::autodiff::ADTensor;
use crate::tensor::backend::Backend;
use crate::tensor::ops::*;
use crate::{execute_ops, register_ops};
register_ops!(
ops UnaryOps,
name ADTensorTransposeOps,
partial |state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
state.output.grad().transpose()
},
);
#[derive(Debug)]
struct DimState {
dim1: usize,
dim2: usize,
}
register_ops!(
ops UnaryOps,
name ADTensorSwapDimOps state DimState,
partial |dims: &DimState, state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
state.output.grad().swap_dims(dims.dim2, dims.dim1)
},
);
impl<B: Backend, const D: usize> TensorOpsTranspose<B::Elem, D> for ADTensor<D, B> {
fn transpose(&self) -> Self {
execute_ops!(
input self.node.clone(),
out TensorOpsTranspose::transpose(&self.tensor()),
ops ADTensorTransposeOps::<B, D>::new(),
)
}
fn swap_dims(&self, dim1: usize, dim2: usize) -> Self {
execute_ops!(
input self.node.clone(),
out TensorOpsTranspose::swap_dims(&self.tensor(), dim1, dim2),
ops ADTensorSwapDimOps::<B, D>::new(DimState { dim1, dim2 }),
)
}
}
#[cfg(test)]
mod tests {
use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};
#[test]
fn should_diff_transpose() {
let data_1 = Data::<f64, 2>::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = Data::<f64, 2>::from([[4.0, 7.0], [2.0, 3.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2.transpose());
let tensor_4 = tensor_3.transpose();
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
assert_eq!(grad_1.to_data(), Data::from([[6.0, 10.0], [6.0, 10.0]]));
assert_eq!(grad_2.to_data(), Data::from([[3.0, 10.0], [3.0, 10.0]]));
}
#[test]
fn should_diff_swap_dims() {
let data_1 = Data::<f64, 3>::from([[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]]);
let data_2 = Data::<f64, 3>::from([[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2.swap_dims(0, 2));
let tensor_4 = tensor_3.matmul(&tensor_2.swap_dims(1, 2));
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
assert_eq!(
grad_1.to_data(),
Data::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]])
);
assert_eq!(
grad_2.to_data(),
Data::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]])
);
}
}

View File

@@ -22,7 +22,6 @@ pub trait Backend:
type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
type TensorPrimitive<const D: usize>: std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
+ TensorOpsTranspose<Self::Elem, D>
+ TensorOpsDetach<Self::Elem, D>
+ Zeros<Self::TensorPrimitive<D>>
+ Ones<Self::TensorPrimitive<D>>

View File

@@ -12,4 +12,3 @@ mod mask;
mod pow;
mod precision;
mod reshape;
mod transpose;

View File

@@ -1,26 +0,0 @@
use crate::{
tensor::{backend::ndarray::NdArrayTensor, ops::*},
NdArrayElement,
};
impl<P, const D: usize> TensorOpsTranspose<P, D> for NdArrayTensor<P, D>
where
P: Default + Clone + std::fmt::Debug + NdArrayElement,
{
fn transpose(&self) -> Self {
self.swap_dims(D - 2, D - 1)
}
fn swap_dims(&self, dim1: usize, dim2: usize) -> Self {
let mut shape = self.shape;
let dim1_new = shape.dims[dim2];
let dim2_new = shape.dims[dim1];
shape.dims[dim1] = dim1_new;
shape.dims[dim2] = dim2_new;
let mut array = self.array.clone();
array.swap_axes(dim1, dim2);
Self { array, shape }
}
}

View File

@@ -178,4 +178,21 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
Self::mul_scalar(tensor, &(-1f32).to_elem::<E>())
}
fn swap_dims<const D: usize>(
tensor: &NdArrayTensor<E, D>,
dim1: usize,
dim2: usize,
) -> NdArrayTensor<E, D> {
let mut shape = tensor.shape;
let dim1_new = shape.dims[dim2];
let dim2_new = shape.dims[dim1];
shape.dims[dim1] = dim1_new;
shape.dims[dim2] = dim2_new;
let mut array = tensor.array.clone();
array.swap_axes(dim1, dim2);
NdArrayTensor { array, shape }
}
}
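A note for readers unfamiliar with the ndarray crate: swap_axes only permutes the strides of the underlying array, so the backend also swaps the matching entries of its own Shape to keep the two consistent. A minimal standalone sketch of the ndarray behaviour being relied on (assuming the ndarray crate as a dependency):

use ndarray::array;

fn main() {
    let mut a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; // shape [2, 3]
    a.swap_axes(0, 1); // in-place stride swap; the logical shape is now [3, 2]
    assert_eq!(a.shape(), &[3, 2]);
    assert_eq!(a[[0, 1]], 4.0); // element that used to live at [1, 0]
}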

View File

@@ -12,4 +12,3 @@ mod mask;
mod pow;
mod precision;
mod reshape;
mod transpose;

View File

@@ -1,26 +0,0 @@
use crate::tensor::{backend::tch::TchTensor, ops::*, Shape};
impl<P: tch::kind::Element, const D: usize> TensorOpsTranspose<P, D> for TchTensor<P, D> {
fn transpose(&self) -> Self {
let tensor = self.tensor.transpose(-2, -1);
let kind = self.kind.clone();
let shape = Shape::from(tensor.size());
Self {
kind,
tensor,
shape,
}
}
fn swap_dims(&self, dim1: usize, dim2: usize) -> Self {
let tensor = self.tensor.transpose(dim1 as i64, dim2 as i64);
let kind = self.kind.clone();
let shape = Shape::from(tensor.size());
Self {
kind,
tensor,
shape,
}
}
}

View File

@@ -127,6 +127,15 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
fn neg<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
Self::mul_scalar(tensor, &(-1f32).to_elem::<E>())
}
fn swap_dims<const D: usize>(
tensor: &TchTensor<E, D>,
dim1: usize,
dim2: usize,
) -> TchTensor<E, D> {
let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64);
to_tensor(tensor)
}
}
fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {

View File

@@ -197,7 +197,7 @@ where
///
/// If the tensor has fewer than 2 dimensions.
pub fn transpose(&self) -> Self {
Self::new(self.value.transpose())
Self::new(B::transpose(&self.value))
}
/// Swap two dimensions.
@@ -206,7 +206,7 @@
///
/// If the given dimensions exceed the shape of the tensor.
pub fn swap_dims(&self, dim1: usize, dim2: usize) -> Self {
Self::new(self.value.swap_dims(dim1, dim2))
Self::new(B::swap_dims(&self.value, dim1, dim2))
}
/// Applies the matrix multiplication operation.

View File

@@ -103,11 +103,14 @@ pub trait TensorOps<B: Backend> {
rhs: &B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
fn neg<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
}
pub trait TensorOpsTranspose<E, const D: usize> {
fn transpose(&self) -> Self;
fn swap_dims(&self, dim1: usize, dim2: usize) -> Self;
fn transpose<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
Self::swap_dims(tensor, D - 2, D - 1)
}
fn swap_dims<const D: usize>(
tensor: &B::TensorPrimitive<D>,
dim1: usize,
dim2: usize,
) -> B::TensorPrimitive<D>;
}
pub trait TensorOpsReshape<B: Backend, const D: usize> {
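The net effect of the hunk above: TensorOpsTranspose is gone, transpose gets a default body in terms of swap_dims, and each backend only has to provide that single primitive. Below is a toy sketch of the pattern, deliberately not burn's real trait; Ops, Dummy and DummyTensor are made-up names and the tensor only carries a shape:

trait Ops {
    type Tensor<const D: usize>;

    fn swap_dims<const D: usize>(t: &Self::Tensor<D>, dim1: usize, dim2: usize) -> Self::Tensor<D>;

    // Provided method: swap the last two dimensions, mirroring the default above.
    fn transpose<const D: usize>(t: &Self::Tensor<D>) -> Self::Tensor<D> {
        Self::swap_dims(t, D - 2, D - 1)
    }
}

struct Dummy;
struct DummyTensor<const D: usize>([usize; D]); // only tracks a shape for the demo

impl Ops for Dummy {
    type Tensor<const D: usize> = DummyTensor<D>;

    fn swap_dims<const D: usize>(t: &DummyTensor<D>, dim1: usize, dim2: usize) -> DummyTensor<D> {
        let mut dims = t.0;
        dims.swap(dim1, dim2);
        DummyTensor(dims)
    }
    // transpose comes for free from the default implementation.
}

fn main() {
    let t = DummyTensor([2, 3, 4]);
    assert_eq!(Dummy::transpose(&t).0, [2, 4, 3]);
}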

View File

@@ -7,3 +7,4 @@ mod mul;
mod neg;
mod softmax;
mod sub;
mod transpose;

View File

@@ -0,0 +1,46 @@
use crate::tensor::TestADTensor;
use burn_tensor::Data;
#[test]
fn should_diff_transpose() {
let data_1 = Data::<f32, 2>::from([[1.0, 7.0], [2.0, 3.0]]);
let data_2 = Data::<f32, 2>::from([[4.0, 7.0], [2.0, 3.0]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2.transpose());
let tensor_4 = tensor_3.transpose();
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
assert_eq!(grad_1.to_data(), Data::from([[6.0, 10.0], [6.0, 10.0]]));
assert_eq!(grad_2.to_data(), Data::from([[3.0, 10.0], [3.0, 10.0]]));
}
#[test]
fn should_diff_swap_dims() {
let data_1 = Data::<f32, 3>::from([[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]]);
let data_2 = Data::<f32, 3>::from([[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]]);
let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);
let tensor_3 = tensor_1.matmul(&tensor_2.swap_dims(0, 2));
let tensor_4 = tensor_3.matmul(&tensor_2.swap_dims(1, 2));
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
assert_eq!(
grad_1.to_data(),
Data::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]])
);
assert_eq!(
grad_2.to_data(),
Data::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]])
);
}