feat: easier scalar API (#46)

Nathaniel Simard 2022-09-25 11:58:37 -04:00 committed by GitHub
parent fe5ed0dbb5
commit 60cd30a768
16 changed files with 64 additions and 67 deletions
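
In short: scalar arguments are now taken by value through `ElementConversion` instead of as `&B::Elem` references, and the `std::ops` impls no longer require the scalar type to equal the backend element type. A minimal sketch of what a call site looks like after this commit (illustrative only, not taken from the diff; the `normalize` helper is hypothetical):

use burn::tensor::{backend::Backend, Tensor};

// Hypothetical helper showing the new call style.
fn normalize<B: Backend>(images: Tensor<B, 2>) -> Tensor<B, 2> {
    // Before this commit: images.div_scalar(&255.to_elem())
    // After: the scalar is passed by value and converted internally...
    images.div_scalar(255.0_f32)
    // ...or, equivalently, via the relaxed std::ops impl: images / 255
}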

View File

@ -1,6 +1,6 @@
use crate::backend::Backend;
use crate::Tensor;
use crate::{ElementConversion, ElementPrecision, Precision};
use crate::{ElementPrecision, Precision};
/// Applies the rectified linear unit function.
pub fn relu<const D: usize, B: Backend>(tensor: &Tensor<B, D>) -> Tensor<B, D> {
@ -9,11 +9,9 @@ pub fn relu<const D: usize, B: Backend>(tensor: &Tensor<B, D>) -> Tensor<B, D> {
/// Applies the Gaussian Error Linear Units function as described in the paper [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
pub fn gelu<const D: usize, B: Backend>(tensor: &Tensor<B, D>) -> Tensor<B, D> {
let x = tensor
.div_scalar(&2.0_f32.sqrt().to_elem())
.erf()
.add_scalar(&1.0_f32.to_elem());
tensor.mul(&x).mul_scalar(&0.5_f32.to_elem())
let x = tensor.div_scalar(2.0_f32.sqrt()).erf().add_scalar(1.0_f32);
tensor.mul(&x) / 2
}
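
For reference, the rewritten `gelu` body computes the exact, erf-based form from the cited paper (in LaTeX notation):

\mathrm{GELU}(x) = x\,\Phi(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)

which is what `tensor.mul(&x) / 2` evaluates, with `x = erf(t / sqrt(2)) + 1`.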
/// Applies the softmax function.

View File

@ -77,7 +77,7 @@ mod tests {
let data = Data::from([2.0, 10.0]);
let tensor = TestADTensor::from_data(data);
let tensor_out = tensor.add_scalar(&5.0);
let tensor_out = tensor.add_scalar(5.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
@ -99,7 +99,7 @@ mod tests {
let tensor_4 = tensor_1.add(&tensor_2);
let tensor_5 = tensor_4
.add(&tensor_3)
.add_scalar(&5.0)
.add_scalar(5.0)
.add(&tensor_1)
.add(&tensor_2);
let tensor_6 = tensor_1.add(&tensor_5);

View File

@ -39,7 +39,7 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimit
let grad: Tensor<B, 1> = Tensor::new(grad);
let val = 1_f64 / self.state.num_elements() as f64;
let ones: Tensor<B, D> = Tensor::new(ones).mul_scalar(&B::Elem::from_elem(val));
let ones: Tensor<B, D> = Tensor::new(ones).mul_scalar(val);
ones.mul(&grad.unsqueeze()).value
}

View File

@ -85,7 +85,7 @@ mod tests {
let data = Data::from([1.0, 7.0]);
let tensor = TestADTensor::from_data(data);
let tensor_out = tensor.div_scalar(&4.0);
let tensor_out = tensor.div_scalar(4.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();

View File

@ -72,7 +72,7 @@ mod tests {
let data = Data::from([2.0, 5.0]);
let tensor = TestADTensor::from_data(data);
let tensor_out = tensor.mul_scalar(&4.0);
let tensor_out = tensor.mul_scalar(4.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();

View File

@ -70,7 +70,7 @@ mod tests {
fn should_diff_sub_scalar() {
let data = Data::from([2.0, 10.0]);
let tensor = TestADTensor::from_data(data);
let tensor_out = tensor.sub_scalar(&5.0);
let tensor_out = tensor.sub_scalar(5.0);
let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();
@ -90,7 +90,7 @@ mod tests {
let tensor_3 = TestADTensor::from_data(data_3);
let tensor_4 = tensor_1.sub(&tensor_2);
let tensor_5 = tensor_4.sub(&tensor_3).sub_scalar(&5.0);
let tensor_5 = tensor_4.sub(&tensor_3).sub_scalar(5.0);
let tensor_6 = tensor_1.sub(&tensor_5);
let grads = tensor_6.backward();

View File

@ -63,7 +63,7 @@ mod tests {
let tensor_3 = tensor_1.matmul(&tensor_2);
let tensor_4 = tensor_3.matmul(&tensor_1);
let tensor_5 = tensor_4.add_scalar(&17.0).add(&tensor_2);
let tensor_5 = tensor_4.add_scalar(17.0).add(&tensor_2);
let grads = tensor_5.backward();

View File

@ -4,9 +4,9 @@ use crate::tensor::backend::Backend;
use crate::tensor::ops::activation::*;
use crate::tensor::ops::*;
use crate::tensor::stats;
use crate::tensor::ElementConversion;
use crate::tensor::{Data, Distribution, Shape};
use crate::BoolTensor;
use crate::Element;
use std::convert::TryInto;
/// A tensor, or *n-dimensional* array.
@ -149,8 +149,8 @@ where
/// Applies element wise addition operation with a scalar.
///
/// `y = x + s`
pub fn add_scalar(&self, other: &B::Elem) -> Self {
Self::new(self.value.add_scalar(other))
pub fn add_scalar<E: ElementConversion>(&self, other: E) -> Self {
Self::new(self.value.add_scalar(&other.to_elem()))
}
/// Applies element wise subtraction operation.
@ -163,8 +163,8 @@ where
/// Applies element wise subtraction operation with a scalar.
///
/// `y = x - s`
pub fn sub_scalar(&self, other: &B::Elem) -> Self {
Self::new(self.value.sub_scalar(other))
pub fn sub_scalar<E: ElementConversion>(&self, other: E) -> Self {
Self::new(self.value.sub_scalar(&other.to_elem()))
}
/// Applies the transpose operation.
@ -206,8 +206,8 @@ where
/// Applies element wise multiplication operation with a scalar.
///
/// `y = x * s`
pub fn mul_scalar(&self, other: &B::Elem) -> Self {
Self::new(self.value.mul_scalar(other))
pub fn mul_scalar<E: ElementConversion>(&self, other: E) -> Self {
Self::new(self.value.mul_scalar(&other.to_elem()))
}
/// Applies element wise division operation.
@ -220,8 +220,8 @@ where
/// Applies element wise division operation with a scalar.
///
/// `y = x / s`
pub fn div_scalar(&self, other: &B::Elem) -> Self {
Self::new(self.value.div_scalar(other))
pub fn div_scalar<E: ElementConversion>(&self, other: E) -> Self {
Self::new(self.value.div_scalar(&other.to_elem()))
}
/// Aggregate all elements in the tensor with the mean operation.
@ -314,28 +314,28 @@ where
}
/// Applies element wise equal comparison and returns a boolean tensor.
pub fn equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
BoolTensor::new(self.value.equal_scalar(other))
pub fn equal_scalar<E: ElementConversion>(&self, other: E) -> BoolTensor<B, D> {
BoolTensor::new(self.value.equal_scalar(&other.to_elem()))
}
/// Applies element wise greater comparison and returns a boolean tensor.
pub fn greater_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
BoolTensor::new(self.value.greater_scalar(other))
pub fn greater_scalar<E: ElementConversion>(&self, other: E) -> BoolTensor<B, D> {
BoolTensor::new(self.value.greater_scalar(&other.to_elem()))
}
/// Applies element wise greater-equal comparison and returns a boolean tensor.
pub fn greater_equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
BoolTensor::new(self.value.greater_equal_scalar(other))
pub fn greater_equal_scalar<E: ElementConversion>(&self, other: E) -> BoolTensor<B, D> {
BoolTensor::new(self.value.greater_equal_scalar(&other.to_elem()))
}
/// Applies element wise lower comparison and returns a boolean tensor.
pub fn lower_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
BoolTensor::new(self.value.lower_scalar(other))
pub fn lower_scalar<E: ElementConversion>(&self, other: E) -> BoolTensor<B, D> {
BoolTensor::new(self.value.lower_scalar(&other.to_elem()))
}
/// Applies element wise lower-equal comparison and returns a boolean tensor.
pub fn lower_equal_scalar(&self, other: &B::Elem) -> BoolTensor<B, D> {
BoolTensor::new(self.value.lower_equal_scalar(other))
pub fn lower_equal_scalar<E: ElementConversion>(&self, other: E) -> BoolTensor<B, D> {
BoolTensor::new(self.value.lower_equal_scalar(&other.to_elem()))
}
/// Create a random tensor of the given shape where each element is sampled from the given
@ -411,8 +411,8 @@ where
}
/// Fill each element with the given value based on the given mask.
pub fn mask_fill(&self, mask: &BoolTensor<B, D>, value: B::Elem) -> Self {
Self::new(self.value.mask_fill(&mask.value, value))
pub fn mask_fill<E: ElementConversion>(&self, mask: &BoolTensor<B, D>, value: E) -> Self {
Self::new(self.value.mask_fill(&mask.value, value.to_elem()))
}
/// Returns a tensor with full precision based on the selected backend.
@ -540,13 +540,13 @@ where
impl<E, const D: usize, B> std::ops::Add<E> for Tensor<B, D>
where
E: Element,
B: Backend<Elem = E>,
E: ElementConversion,
B: Backend,
{
type Output = Self;
fn add(self, other: E) -> Self {
Tensor::add_scalar(&self, &other)
Tensor::add_scalar(&self, other)
}
}
@ -563,13 +563,13 @@ where
impl<E, const D: usize, B> std::ops::Sub<E> for Tensor<B, D>
where
E: Element,
B: Backend<Elem = E>,
E: ElementConversion,
B: Backend,
{
type Output = Self;
fn sub(self, other: E) -> Self {
Tensor::sub_scalar(&self, &other)
Tensor::sub_scalar(&self, other)
}
}
@ -586,13 +586,13 @@ where
impl<E, const D: usize, B> std::ops::Mul<E> for Tensor<B, D>
where
E: Element,
B: Backend<Elem = E>,
E: ElementConversion,
B: Backend,
{
type Output = Self;
fn mul(self, other: E) -> Self {
Tensor::mul_scalar(&self, &other)
Tensor::mul_scalar(&self, other)
}
}
@ -609,13 +609,13 @@ where
impl<E, const D: usize, B> std::ops::Div<E> for Tensor<B, D>
where
E: Element,
B: Backend<Elem = E>,
E: ElementConversion,
B: Backend,
{
type Output = Self;
fn div(self, other: E) -> Self {
Tensor::div_scalar(&self, &other)
Tensor::div_scalar(&self, other)
}
}
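
Taken together, the relaxed bounds mean the scalar no longer has to be exactly `B::Elem`: any `E: ElementConversion` is converted internally. A small sketch combining the new comparison, mask, and operator signatures (the `scaled_mask` function is hypothetical, not part of the commit):

use burn::tensor::{backend::Backend, Tensor};

// Hypothetical example of the by-value scalar API on a 2-D tensor.
fn scaled_mask<B: Backend>(t: &Tensor<B, 2>) -> Tensor<B, 2> {
    // Comparison against a plain float literal returns a boolean tensor.
    let big = t.greater_scalar(4.0);
    // mask_fill now takes the fill value by value; the Mul<E> impl converts the integer.
    t.mask_fill(&big, 0.0_f32) * 2
}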

View File

@ -1,4 +1,4 @@
use crate::{backend::Backend, ElementConversion, Tensor};
use crate::{backend::Backend, Tensor};
pub fn var<B: Backend, const D: usize>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
let mean = tensor.mean_dim(dim);
@ -32,6 +32,5 @@ pub fn var_with_mean_n<B: Backend, const D: usize>(
dim: usize,
n: usize,
) -> Tensor<B, D> {
let n = (n as f32).to_elem();
tensor.sub(mean).powf(2.0).sum_dim(dim).div_scalar(&n)
tensor.sub(mean).powf(2.0).sum_dim(dim).div_scalar(n as f32)
}
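
Behaviourally nothing changes here; the function still computes the variance over `dim` normalized by the caller-supplied `n`:

\operatorname{var}_n(x) = \frac{1}{n} \sum_i \left(x_i - \bar{x}\right)^2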

View File

@ -6,7 +6,7 @@ fn test_greater_scalar() {
let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1);
let data_actual = tensor_1.greater_scalar(&4.0);
let data_actual = tensor_1.greater_scalar(4.0);
let data_expected = Data::from([[false, false, false], [false, false, true]]);
assert_eq!(data_expected, data_actual.to_data());
@ -17,7 +17,7 @@ fn test_greater_equal_scalar() {
let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1);
let data_actual = tensor_1.greater_equal_scalar(&4.0);
let data_actual = tensor_1.greater_equal_scalar(4.0);
let data_expected = Data::from([[false, false, false], [false, true, true]]);
assert_eq!(data_expected, data_actual.to_data());
@ -54,7 +54,7 @@ fn test_lower_scalar() {
let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1);
let data_actual = tensor_1.lower_scalar(&4.0);
let data_actual = tensor_1.lower_scalar(4.0);
let data_expected = Data::from([[true, true, true], [true, false, false]]);
assert_eq!(data_expected, data_actual.to_data());
@ -65,7 +65,7 @@ fn test_lower_equal_scalar() {
let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor_1 = Tensor::<TestBackend, 2>::from_data(data_1);
let data_actual = tensor_1.lower_equal_scalar(&4.0);
let data_actual = tensor_1.lower_equal_scalar(4.0);
let data_expected = Data::from([[true, true, true], [true, true, false]]);
assert_eq!(data_expected, data_actual.to_data());

View File

@ -9,7 +9,7 @@ use burn::optim::momentum::MomentumConfig;
use burn::optim::{Optimizer, Sgd, SgdConfig};
use burn::tensor::backend::{ADBackend, Backend};
use burn::tensor::loss::cross_entropy_with_logits;
use burn::tensor::{Data, ElementConversion, Shape, Tensor};
use burn::tensor::{Data, Shape, Tensor};
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric};
use burn::train::{ClassificationLearner, ClassificationOutput, Train};
use burn::train::{SupervisedData, SupervisedTrainerBuilder};
@ -145,7 +145,7 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
.map(|item| Data::<f32, 2>::from(item.image))
.map(|data| Tensor::<B, 2>::from_data(data.convert()))
.map(|tensor| tensor.reshape(Shape::new([1, 784])))
.map(|tensor| tensor.div_scalar(&255.to_elem()))
.map(|tensor| tensor / 255)
.collect();
let targets = items
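
The division by 255 (now written with the `Div<E>` operator impl) rescales the 8-bit MNIST pixel intensities into the unit interval, i.e. x_norm = x / 255 ∈ [0, 1], before the images are fed to the model.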

View File

@ -2,7 +2,7 @@ use crate as burn;
use crate::config::Config;
use crate::module::Forward;
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, ElementConversion, Tensor};
use crate::tensor::{Distribution, Tensor};
/// Configuration to create a [Dropout](Dropout) layer.
#[derive(Config)]
@ -35,10 +35,10 @@ impl<B: Backend, const D: usize> Forward<Tensor<B, D>, Tensor<B, D>> for Dropout
}
let random = input.random_like(Distribution::Bernoulli(self.prob));
let mask = random.equal_scalar(&1.to_elem());
let x = input.mask_fill(&mask, 0.to_elem());
let mask = random.equal_scalar(1);
let x = input.mask_fill(&mask, 0.0_f32);
x.div_scalar(&(1.0 - self.prob).to_elem())
x / (1.0 - self.prob)
}
}
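
The rewritten forward pass is the usual inverted-dropout scaling: positions where the Bernoulli draw equals 1 are zeroed (probability `prob`), and the survivors are divided by `1 - prob` so the expected activation is unchanged:

\mathbb{E}[\text{output}] = \frac{(1 - p)\,x + p \cdot 0}{1 - p} = x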

View File

@ -4,7 +4,7 @@ use crate::config::Config;
use crate::module::Module;
use crate::module::{Forward, Param};
use crate::tensor::backend::Backend;
use crate::tensor::{ElementConversion, Shape, Tensor};
use crate::tensor::{Shape, Tensor};
/// Configuration to create a [LayerNorm](LayerNorm) layer.
#[derive(Config)]
@ -45,7 +45,7 @@ impl<B: Backend, const D: usize> Forward<Tensor<B, D>, Tensor<B, D>> for LayerNo
let input_normalized = input
.sub(&mean)
.div(&var.powf(0.5).add_scalar(&self.epsilon.to_elem()));
.div(&var.powf(0.5).add_scalar(self.epsilon));
input_normalized
.mul(&self.gamma.unsqueeze())

View File

@ -34,7 +34,7 @@ impl<B: ADBackend> WeightDecay<B> {
let id = id.to_string();
let grad = match self.gradients.get::<Tensor<B::InnerBackend, D>>(&id) {
Some(grad_last_step) => grad_last_step.mul_scalar(&self.penalty).add(&grad),
Some(grad_last_step) => grad_last_step.mul_scalar(self.penalty).add(&grad),
None => grad,
};

View File

@ -46,8 +46,8 @@ impl<B: ADBackend> Momentum<B> {
let velocity = match self.velocity.get::<Tensor<B::InnerBackend, D>>(&id) {
Some(grad_last_step) => grad
.mul_scalar(&(1.0 - self.dampening).to_elem())
.add(&grad_last_step.mul_scalar(&self.momentum)),
.mul_scalar(1.0 - self.dampening)
.add(&grad_last_step.mul_scalar(self.momentum)),
None => grad.clone(),
};
@ -55,7 +55,7 @@ impl<B: ADBackend> Momentum<B> {
self.velocity.register_any(id, velocity.clone());
match self.nesterov {
true => velocity.mul_scalar(&self.momentum).add(&grad),
true => velocity.mul_scalar(self.momentum).add(&grad),
false => velocity,
}
}
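
Written out, the update this hunk implements is

v_t = (1 - \text{dampening})\, g_t + \text{momentum} \cdot v_{t-1}

and the returned gradient is \text{momentum} \cdot v_t + g_t when `nesterov` is set, or v_t otherwise.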

View File

@ -64,7 +64,7 @@ impl<B: ADBackend> Optimizer for Sgd<B> {
None => grad,
};
let delta = grad.mul_scalar(&self.learning_rate);
let delta = grad.mul_scalar(self.learning_rate);
tensor.update(tensor.inner() - delta);
}
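
The final step is the plain SGD rule applied with the (possibly momentum- or weight-decay-adjusted) gradient g and learning rate \eta (`self.learning_rate`):

\theta \leftarrow \theta - \eta\, g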