Feature Addition: PRelu Module (#1328)

Arjun31415 2024-02-24 20:54:22 +05:30 committed by GitHub
parent 1da47c9bf1
commit 8e23057c6b
10 changed files with 218 additions and 0 deletions


@@ -115,6 +115,7 @@ Burn comes with built-in modules that you can use to build your own modules.
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
| `Dropout` | `nn.Dropout` |
| `Gelu` | `nn.Gelu` |
| `Prelu` | `nn.PReLU` |
| `Linear` | `nn.Linear` |
| `Embedding` | `nn.Embedding` |
| `Relu` | `nn.ReLU` |
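
The new `Prelu` row is the module counterpart of PyTorch's `nn.PReLU`. As a rough usage sketch (the `burn::nn` import paths and the generic helper below are illustrative assumptions, not part of this diff):

```rust
use burn::nn::{PRelu, PReluConfig};
use burn::tensor::{backend::Backend, Tensor};

// Build a PRelu with its defaults (a single weight, initialised to 0.25),
// roughly the counterpart of PyTorch's `nn.PReLU()`.
fn prelu_layer<B: Backend, const D: usize>(
    input: Tensor<B, D>,
    device: &B::Device,
) -> Tensor<B, D> {
    let prelu: PRelu<B> = PReluConfig::new().init(device);
    prelu.forward(input)
}
```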


@@ -280,3 +280,4 @@ Those operations are only available for `Bool` tensors.
| `activation::softmax(tensor, dim)` | Similar to `nn.functional.softmax(tensor, dim)` |
| `activation::softplus(tensor, beta)` | Similar to `nn.functional.softplus(tensor, beta)` |
| `activation::tanh(tensor)` | Similar to `nn.functional.tanh(tensor)` |
| `activation::prelu(tensor, alpha)` | Similar to `nn.functional.prelu(tensor, weight)` |


@@ -32,6 +32,7 @@ mod tests {
    // test activation
    burn_tensor::testgen_gelu!();
    burn_tensor::testgen_prelu!();
    burn_tensor::testgen_relu!();
    burn_tensor::testgen_softmax!();
    burn_tensor::testgen_sigmoid!();


@@ -24,6 +24,7 @@ mod linear;
mod norm;
mod padding;
mod pos_encoding;
mod prelu;
mod relu;
mod rnn;
mod unfold;
@@ -36,6 +37,7 @@ pub use linear::*;
pub use norm::*;
pub use padding::*;
pub use pos_encoding::*;
pub use prelu::*;
pub use relu::*;
pub use rnn::*;
pub use unfold::*;


@@ -0,0 +1,47 @@
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Parametric ReLU (PReLU) layer.
#[derive(Module, Debug)]
pub struct PRelu<B: Backend> {
    /// The learnable weights for PReLU, of shape `[1]` or `[num_parameters]`; in the latter
    /// case, `num_parameters` must equal the number of channels in the input tensor.
    pub alpha: Param<Tensor<B, 1>>,
}

/// Configuration to create a [Parametric ReLU](PRelu) layer.
#[derive(Config, Debug)]
pub struct PReluConfig {
    /// The number of learnable weights: either 1 or the number of channels in the input.
    #[config(default = "1")]
    pub num_parameters: usize,
    /// The initial value of the learnable weight `alpha`. Defaults to 0.25.
    #[config(default = "0.25")]
    pub alpha: f64,
}

impl PReluConfig {
    /// Initialize a new [Parametric ReLU](PRelu) layer.
    pub fn init<B: Backend>(&self, device: &B::Device) -> PRelu<B> {
        PRelu {
            // `alpha` is a tensor of length `num_parameters`, filled with the initial value.
            alpha: Param::from(
                Initializer::Constant { value: self.alpha }.init([self.num_parameters], device),
            ),
        }
    }
}

impl<B: Backend> PRelu<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any]`
    /// - output: `[..., any]`
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        crate::tensor::activation::prelu(input, self.alpha.val())
    }
}
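
A short usage sketch for the per-channel case; the `with_num_parameters` / `with_alpha` setters are assumed to be the builder methods generated by `#[derive(Config)]`, so treat the exact names as illustrative:

```rust
use burn::nn::{PRelu, PReluConfig};
use burn::tensor::{backend::Backend, Tensor};

// One learnable weight per channel for inputs of shape [batch, 8, h, w].
// The setter names are assumed from the `Config` derive, not spelled out in this diff.
fn per_channel_prelu<B: Backend>(input: Tensor<B, 4>, device: &B::Device) -> Tensor<B, 4> {
    let prelu: PRelu<B> = PReluConfig::new()
        .with_num_parameters(8)
        .with_alpha(0.1)
        .init(device);

    // Shapes are preserved: [batch, 8, h, w] -> [batch, 8, h, w].
    prelu.forward(input)
}
```

With `num_parameters > 1`, the input must have its channel dimension at index 1, as enforced by `check_prelu_shape` further down.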


@@ -13,6 +13,35 @@ pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::from_primitive(B::gelu(tensor.primitive))
}

/// Applies the Parametric ReLU activation
/// `PReLU(x) = max(0, x) + alpha * min(0, x)`.
///
/// The input tensor is assumed to be of shape `[batch_size, channels, ...]`;
/// `alpha` is assumed to be of shape `[channels]` or `[1]`.
pub fn prelu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: Tensor<B, 1>,
) -> Tensor<B, D> {
    check!(TensorCheck::check_prelu_shape::<D>(
        &tensor.shape(),
        &alpha.shape()
    ));

    let weight = if alpha.dims()[0] == 1 {
        // A single weight: reshape it to `[1; D]` so its rank matches the input and it
        // broadcasts over every element.
        alpha.reshape([1; D])
    } else {
        // Multiple weights: here D >= 2, because D == 1 with more than one weight is
        // rejected by the shape check above. Reshape the weights to `[1, channels, 1, ...]`
        // so they broadcast along the channel dimension.
        let num_weights = alpha.dims()[0];
        let mut s = [1; D];
        s[1] = num_weights;

        alpha.reshape(s)
    };

    Tensor::from_primitive(B::prelu(tensor.primitive, weight.primitive))
}

/// Applies the softmax function on the input tensor along the given dimension.
///
/// `softmax(x_i) = exp(x_i) / sum_j(exp(x_j))`

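To illustrate the reshaping above: with an input of shape `[batch, channels, ...]` and an `alpha` of length `channels`, the weights are broadcast as `[1, channels, 1, ...]`. A minimal sketch, assuming the public `burn::tensor` re-exports of the items in this file:

```rust
use burn::tensor::{activation, backend::Backend, Tensor};

// Per-channel PReLU on a rank-3 input.
// input: [batch, channels, length], alpha: [channels] -> reshaped internally to [1, channels, 1].
fn prelu_per_channel<B: Backend>(input: Tensor<B, 3>, alpha: Tensor<B, 1>) -> Tensor<B, 3> {
    activation::prelu(input, alpha)
}
```

An `alpha` whose length is neither 1 nor the size of dimension 1 is rejected by `check_prelu_shape`, shown in the next file.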

@@ -694,6 +694,44 @@ impl TensorCheck {
        check
    }

    /// Checks that the PReLU weight is either a single scalar or provides one weight per
    /// channel (dimension 1) of the input tensor.
    pub(crate) fn check_prelu_shape<const D: usize>(
        shape_tensor: &Shape<D>,
        shape_weight: &Shape<1>,
    ) -> Self {
        let mut check = Self::Ok;
        let num_weights = shape_weight.dims[0];

        if num_weights == 1 {
            // A single weight is valid for any input rank.
            check
        } else if D >= 2 {
            // With multiple weights, dimension 1 of the input is the channel dimension
            // and must match the number of weights.
            let channels = shape_tensor.dims[1];
            if channels != num_weights {
                check = check.register(
                    "PReLu",
                    TensorError::new(
                        "Number of channels in input tensor and number of weights must be equal",
                    )
                    .details(format!(
                        "Got no. of channels: {}, no. of weights: {}",
                        channels, num_weights
                    )),
                );
            }
            check
        } else {
            // Rank-1 input: there is no channel dimension, so more than one weight is invalid.
            check = check.register(
                "PReLu",
                TensorError::new(
                    "Number of channels in input tensor and number of weights must be equal",
                )
                .details(format!(
                    "Got no. of channels: {}, no. of weights: {}",
                    1, num_weights
                )),
            );
            check
        }
    }

    /// Checks aggregate dimension such as mean and sum.
    pub(crate) fn aggregate_dim<const D: usize>(ops: &str, dim: usize) -> Self {


@@ -58,6 +58,18 @@ pub trait ActivationOps<B: Backend> {
        B::float_div_scalar(x, 2i32.elem())
    }

    /// Applies the PReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `alpha` - The weight tensor, with the same rank as `tensor` and broadcastable over it.
    fn prelu<const D: usize>(
        tensor: FloatTensor<B, D>,
        alpha: FloatTensor<B, D>,
    ) -> FloatTensor<B, D> {
        // PReLU(x) = x for x >= 0 and alpha * x for x < 0: mark the negative entries
        // and replace them with the scaled values.
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul(tensor.clone(), alpha);

        B::float_mask_where(tensor, mask, scaled_tensor)
    }
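
    // A hedged worked example of the default implementation above: with
    // tensor = [-2.0, 3.0] and alpha (already broadcast) = [0.25, 0.25],
    // mask = [true, false] and scaled_tensor = [-0.5, 0.75], so
    // float_mask_where keeps the scaled value where the mask is true,
    // giving [-0.5, 3.0].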
    /// Applies the Gelu activation function backward.
    ///


@@ -1,5 +1,6 @@
pub(crate) mod gelu;
pub(crate) mod mish;
pub(crate) mod prelu;
pub(crate) mod relu;
pub(crate) mod sigmoid;
pub(crate) mod silu;


@@ -0,0 +1,86 @@
#[burn_tensor_testgen::testgen(prelu)]
mod tests {
    use super::*;
    use burn_tensor::{activation, Data, Tensor};

    #[test]
    fn test_prelu_2_dimension() {
        let data = [
            [-1.1, 0.0, 1.2, 0.25, -5.4],
            [-4.567, 0.56, -1.55, 99.9, 0.0],
        ];
        let tensor = TestTensor::from(data);
        let data_actual =
            activation::prelu(tensor, TestTensor::from([0.5, 0.25, 0.0, -0.8, -0.4])).into_data();
        let data_expected = Data::from([
            [-0.5500, 0.0000, 1.2000, 0.2500, 2.1600],
            [-2.2835, 0.5600, -0.0000, 99.9000, -0.0000],
        ]);
        data_expected.assert_approx_eq(&data_actual, 9);
    }

    #[test]
    fn test_prelu_2_dimension_scalar_weight() {
        let data = [
            [-1.1, 0.0, 1.2, 0.25, -5.4],
            [-4.567, 0.56, -1.55, 99.9, 0.0],
        ];
        let tensor = TestTensor::from(data);
        let data_actual = activation::prelu(tensor, TestTensor::from([-0.8])).into_data();
        let data_expected = Data::from([
            [0.8800, -0.0000, 1.2000, 0.2500, 4.3200],
            [3.6536, 0.5600, 1.2400, 99.9000, -0.0000],
        ]);
        data_expected.assert_approx_eq(&data_actual, 7);
    }

    #[test]
    fn test_prelu_positives() {
        // Check that positive values are untouched.
        let data = [[
            0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
        ]];
        let tensor = TestTensor::from(data);
        let data_actual = activation::prelu(tensor, TestTensor::from([0.25])).into_data();
        let data_expected = Data::from(data);
        data_expected.assert_approx_eq(&data_actual, 9);
    }

    #[test]
    fn test_prelu_zero_weight() {
        // With a weight of 0, PReLU behaves like ReLU.
        let data = [-1.1, 0.0, 1.2, 0.25, -5.4];
        let tensor = TestTensor::from(data);
        let data_actual = activation::prelu(tensor, TestTensor::from([0.0])).into_data();
        let data_expected = Data::from([0.0, 0.0, 1.2, 0.25, 0.0]);
        data_expected.assert_approx_eq(&data_actual, 9);
    }

    #[test]
    fn test_prelu_some_weight() {
        // With a non-zero weight, PReLU behaves like leaky ReLU.
        let data = [-1.1, 0.0, 1.2, 0.25, -5.4];
        let tensor = TestTensor::from(data);
        let data_actual = activation::prelu(tensor, TestTensor::from([0.5])).into_data();
        let data_expected = Data::from([-0.550, 0.0, 1.20, 0.250, -2.70]);
        data_expected.assert_approx_eq(&data_actual, 9);
    }

    #[test]
    #[should_panic]
    fn test_prelu_single_dim_multi_weight() {
        // Should panic: the 1-D input has no channel dimension, but multiple weights are given.
        let data = [-1.1, 2.0, 1.2, 0.25, -5.4];
        let tensor = TestTensor::from(data);
        let data_actual =
            activation::prelu(tensor, TestTensor::from([0.5, -0.25, 0.0, 0.5, -1.0])).into_data();
        let data_expected = Data::from([-0.550, 0.0, 1.20, 0.250, -2.70]);
        data_expected.assert_approx_eq(&data_actual, 9);
    }

    #[test]
    #[should_panic]
    fn test_prelu_multi_dim_wrong_weights() {
        // Should panic: the input has 5 channels but only 2 weights are given.
        let data = [
            [-1.1, 0.0, 1.2, 0.25, -5.4],
            [-4.567, 0.56, -1.55, 99.9, 0.0],
        ];
        let tensor = TestTensor::from(data);
        let data_actual = activation::prelu(tensor, TestTensor::from([-0.8, 0.1])).into_data();
    }
}
}