diff --git a/burn-tensor/examples/autodiff_simple.rs b/burn-tensor/examples/autodiff_simple.rs
index c9e7614b0..9907ba0c1 100644
--- a/burn-tensor/examples/autodiff_simple.rs
+++ b/burn-tensor/examples/autodiff_simple.rs
@@ -1,4 +1,5 @@
 use burn_tensor::{activation, backend, Data, Distribution, Shape, Tensor};
+use rand::{rngs::StdRng, SeedableRng};
 
 fn loss<B: backend::Backend>(x: &Tensor<B, 2>, y: &Tensor<B, 2>) -> Tensor<B, 2> {
     let z = x.matmul(y);
@@ -35,8 +36,9 @@ fn run(x: Data, y: Data) {
 
 fn main() {
     // Same data for all backends
-    let x = Data::random(Shape::new([2, 3]), Distribution::Standard);
-    let y = Data::random(Shape::new([3, 1]), Distribution::Standard);
+    let mut rng = StdRng::from_entropy();
+    let x = Data::random(Shape::new([2, 3]), Distribution::Standard, &mut rng);
+    let y = Data::random(Shape::new([3, 1]), Distribution::Standard, &mut rng);
 
     #[cfg(feature = "ndarray")]
     {
diff --git a/burn-tensor/src/tensor/activation/base.rs b/burn-tensor/src/tensor/activation/base.rs
index bee6ce830..af2d2009b 100644
--- a/burn-tensor/src/tensor/activation/base.rs
+++ b/burn-tensor/src/tensor/activation/base.rs
@@ -1,15 +1,27 @@
 use crate::backend::Backend;
 use crate::Tensor;
-use crate::{ElementPrecision, Precision};
+use crate::{ElementConversion, ElementPrecision, Precision};
 
+/// Applies the rectified linear unit function.
 pub fn relu<B: Backend, const D: usize>(tensor: &Tensor<B, D>) -> Tensor<B, D> {
     tensor.relu()
 }
 
+/// Applies the Gaussian Error Linear Units function, as described in the paper [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
+pub fn gelu<B: Backend, const D: usize>(tensor: &Tensor<B, D>) -> Tensor<B, D> {
+    let x = tensor
+        .div_scalar(&2.0_f32.sqrt().to_elem())
+        .erf()
+        .add_scalar(&1.0_f32.to_elem());
+    tensor.mul(&x).mul_scalar(&0.5_f32.to_elem())
+}
+
+/// Applies the softmax function.
 pub fn softmax<B: Backend, const D: usize>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
     log_softmax(tensor, dim).exp()
 }
 
+/// Applies the log softmax function.
 pub fn log_softmax<B: Backend, const D: usize>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
     let tensor_tmp = match B::Elem::precision() {
         Precision::Half => {
diff --git a/burn-tensor/tests/tensor/activation/gelu.rs b/burn-tensor/tests/tensor/activation/gelu.rs
new file mode 100644
index 000000000..65813e74a
--- /dev/null
+++ b/burn-tensor/tests/tensor/activation/gelu.rs
@@ -0,0 +1,18 @@
+use super::super::TestBackend;
+use burn_tensor::activation;
+use burn_tensor::{Data, Tensor};
+
+#[test]
+fn test_gelu() {
+    let data = Data::from([[
+        0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
+    ]]);
+    let tensor = Tensor::<TestBackend, 2>::from_data(data);
+
+    let data_actual = activation::gelu(&tensor).to_data();
+
+    let data_expected = Data::from([[
+        0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,
+    ]]);
+    data_expected.assert_approx_eq(&data_actual, 3);
+}
diff --git a/burn-tensor/tests/tensor/activation/mod.rs b/burn-tensor/tests/tensor/activation/mod.rs
index 6eab47fac..06405bbe9 100644
--- a/burn-tensor/tests/tensor/activation/mod.rs
+++ b/burn-tensor/tests/tensor/activation/mod.rs
@@ -1,2 +1,3 @@
+mod gelu;
 mod relu;
 mod softmax;
diff --git a/burn/src/nn/gelu.rs b/burn/src/nn/gelu.rs
new file mode 100644
index 000000000..9d203589d
--- /dev/null
+++ b/burn/src/nn/gelu.rs
@@ -0,0 +1,19 @@
+use crate::module::Forward;
+use crate::tensor::backend::Backend;
+use crate::tensor::Tensor;
+
+/// Applies the Gaussian Error Linear Units function element-wise.
+#[derive(Clone, Debug, Default)]
+pub struct GELU {}
+
+impl GELU {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl<B: Backend, const D: usize> Forward<Tensor<B, D>, Tensor<B, D>> for GELU {
+    fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D> {
+        crate::tensor::activation::gelu(&input)
+    }
+}
diff --git a/burn/src/nn/mod.rs b/burn/src/nn/mod.rs
index 3f2e8019e..210029b37 100644
--- a/burn/src/nn/mod.rs
+++ b/burn/src/nn/mod.rs
@@ -1,9 +1,11 @@
 mod dropout;
+mod gelu;
 mod layer_norm;
 mod linear;
 mod relu;
 
 pub use dropout::*;
+pub use gelu::*;
 pub use layer_norm::*;
 pub use linear::*;
 pub use relu::*;
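Reviewer note (not part of the patch): the new `gelu` computes the exact form `0.5 * x * (1 + erf(x / sqrt(2)))` rather than the tanh approximation, and `nn::GELU::forward` simply delegates to it. The sketch below shows the two entry points side by side on the same input used by the new unit test. The `burn::module::Forward` and `burn::tensor::*` paths, the `Elem = f32` bound, and the `check_gelu` helper are assumptions for illustration, inferred from the imports in the diff, not code from this PR.

```rust
use burn::module::Forward;
use burn::nn::GELU;
use burn::tensor::activation;
use burn::tensor::backend::Backend;
use burn::tensor::{Data, Tensor};

/// Sanity check: the functional `activation::gelu` and the `nn::GELU` module
/// should produce the same values for the same input.
fn check_gelu<B: Backend<Elem = f32>>() {
    // Same 1x10 input as the new test in burn-tensor/tests/tensor/activation/gelu.rs.
    let input = Tensor::<B, 2>::from_data(Data::from([[
        0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
    ]]));

    // Functional form, straight from burn-tensor.
    let functional = activation::gelu(&input);

    // Module form: GELU::new() is stateless and its `forward` delegates to the
    // same activation function.
    let from_module = GELU::new().forward(input);

    // Agreement up to 3 decimal places, as in the unit test.
    functional
        .to_data()
        .assert_approx_eq(&from_module.to_data(), 3);
}
```

A concrete backend type (for example, whichever backend the `ndarray` feature enables) would be substituted for `B` at the call site.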