Feat/gelu (#45)

Nathaniel Simard 2022-09-24 13:08:08 -04:00 committed by GitHub
parent a84df25d40
commit fe5ed0dbb5
6 changed files with 57 additions and 3 deletions

@ -1,4 +1,5 @@
use burn_tensor::{activation, backend, Data, Distribution, Shape, Tensor};
+use rand::{rngs::StdRng, SeedableRng};

fn loss<B: backend::Backend>(x: &Tensor<B, 2>, y: &Tensor<B, 2>) -> Tensor<B, 2> {
    let z = x.matmul(y);
@ -35,8 +36,9 @@ fn run<B: backend::Backend>(x: Data<B::Elem, 2>, y: Data<B::Elem, 2>) {
fn main() {
    // Same data for all backends
-    let x = Data::random(Shape::new([2, 3]), Distribution::Standard);
-    let y = Data::random(Shape::new([3, 1]), Distribution::Standard);
+    let mut rng = StdRng::from_entropy();
+    let x = Data::random(Shape::new([2, 3]), Distribution::Standard, &mut rng);
+    let y = Data::random(Shape::new([3, 1]), Distribution::Standard, &mut rng);

    #[cfg(feature = "ndarray")]
    {
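Note: StdRng::from_entropy() draws a fresh seed from the operating system, so every run of the example uses different random data (though still the same data for each backend within a run). A minimal sketch of how a fixed seed could be used instead for reproducible runs; the seed value is arbitrary and not part of this commit:

    let mut rng = StdRng::seed_from_u64(42); // hypothetical fixed seed for reproducibility
    let x = Data::random(Shape::new([2, 3]), Distribution::Standard, &mut rng);
    let y = Data::random(Shape::new([3, 1]), Distribution::Standard, &mut rng);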

@ -1,15 +1,27 @@
use crate::backend::Backend;
use crate::Tensor;
-use crate::{ElementPrecision, Precision};
+use crate::{ElementConversion, ElementPrecision, Precision};

/// Applies the rectified linear unit function.
pub fn relu<const D: usize, B: Backend>(tensor: &Tensor<B, D>) -> Tensor<B, D> {
    tensor.relu()
}
+/// Applies the Gaussian Error Linear Units function as described in the paper [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
+pub fn gelu<const D: usize, B: Backend>(tensor: &Tensor<B, D>) -> Tensor<B, D> {
+    let x = tensor
+        .div_scalar(&2.0_f32.sqrt().to_elem())
+        .erf()
+        .add_scalar(&1.0_f32.to_elem());
+    tensor.mul(&x).mul_scalar(&0.5_f32.to_elem())
+}
/// Applies the softmax function.
pub fn softmax<const D: usize, B: Backend>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    log_softmax(tensor, dim).exp()
}

/// Applies the log softmax function.
pub fn log_softmax<const D: usize, B: Backend>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    let tensor_tmp = match B::Elem::precision() {
        Precision::Half => {
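For reference, the closed form computed by the gelu function above is the exact erf-based definition from the paper, not the tanh approximation:

    GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))

The div_scalar / erf / add_scalar chain builds the (1 + erf(x / sqrt(2))) factor, and the final mul / mul_scalar applies x and the 0.5 coefficient element-wise.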

@ -0,0 +1,18 @@
use super::super::TestBackend;
use burn_tensor::activation;
use burn_tensor::{Data, Tensor};

#[test]
fn test_gelu() {
    let data = Data::from([[
        0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
    ]]);
    let tensor = Tensor::<TestBackend, 2>::from_data(data);

    let data_actual = activation::gelu(&tensor).to_data();

    let data_expected = Data::from([[
        0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,
    ]]);
    data_expected.assert_approx_eq(&data_actual, 3);
}
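As a quick sanity check on the expected values, plugging the first input into the erf-based formula gives, with rounded intermediate values:

    gelu(0.5447) = 0.5 * 0.5447 * (1 + erf(0.5447 / sqrt(2)))
                 ≈ 0.5 * 0.5447 * (1 + 0.4141)
                 ≈ 0.3851

which matches the first expected output at the 3-decimal precision checked by assert_approx_eq.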

@ -1,2 +1,3 @@
+mod gelu;
mod relu;
mod softmax;

burn/src/nn/gelu.rs (new file, 19 additions)

@ -0,0 +1,19 @@
use crate::module::Forward;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Applies the Gaussian Error Linear Units function element-wise.
#[derive(Clone, Debug, Default)]
pub struct GELU {}

impl GELU {
    pub fn new() -> Self {
        Self {}
    }
}

impl<B: Backend, const D: usize> Forward<Tensor<B, D>, Tensor<B, D>> for GELU {
    fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        crate::tensor::activation::gelu(&input)
    }
}
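A minimal usage sketch of the module; here `B` is a placeholder for whichever Backend implementation is enabled, and `input` is assumed to be an existing Tensor<B, 2>:

    // GELU holds no state, so `GELU::new()` and `GELU::default()` are equivalent.
    let gelu = GELU::new();
    let output = gelu.forward(input); // element-wise GELU, same shape as the input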

@ -1,9 +1,11 @@
mod dropout;
+mod gelu;
mod layer_norm;
mod linear;
mod relu;

pub use dropout::*;
+pub use gelu::*;
pub use layer_norm::*;
pub use linear::*;
pub use relu::*;