mirror of https://github.com/tracel-ai/burn.git
Feature Addition: PRelu Module (#1328)
parent 1da47c9bf1
commit 8e23057c6b
@@ -115,6 +115,7 @@ Burn comes with built-in modules that you can use to build your own modules.
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
| `Dropout` | `nn.Dropout` |
| `Gelu` | `nn.Gelu` |
| `PRelu` | `nn.PReLU` |
| `Linear` | `nn.Linear` |
| `Embedding` | `nn.Embedding` |
| `Relu` | `nn.ReLU` |
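
For the PyTorch comparison above, a minimal usage sketch of the Burn side (a hedged example, not part of this commit; the backend alias `B` is a placeholder and the tensor-construction helper may differ between Burn versions):

```rust
use burn::nn::{PRelu, PReluConfig};
use burn::tensor::{backend::Backend, Tensor};

// Counterpart of `torch.nn.PReLU()`: a single learnable weight, default 0.25.
fn prelu_module_example<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
    let prelu: PRelu<B> = PReluConfig::new().init(device);
    let input = Tensor::<B, 2>::from_floats([[-1.0, 0.5], [2.0, -3.0]], device);
    prelu.forward(input)
}
```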
@@ -280,3 +280,4 @@ Those operations are only available for `Bool` tensors.
| `activation::softmax(tensor, dim)` | Similar to `nn.functional.softmax(tensor, dim)` |
| `activation::softplus(tensor, beta)` | Similar to `nn.functional.softplus(tensor, beta)` |
| `activation::tanh(tensor)` | Similar to `nn.functional.tanh(tensor)` |
| `activation::prelu(tensor, alpha)` | Similar to `nn.functional.prelu(tensor, weight)` |
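
A hedged sketch of the functional form listed above, mirroring `nn.functional.prelu` (illustrative only; `B`, the shapes, and the creation helpers are placeholders):

```rust
use burn::tensor::{activation, backend::Backend, Tensor};

// Like `torch.nn.functional.prelu(input, weight)`: `alpha` holds either a single
// value or one value per channel (dimension 1 of the input).
fn functional_prelu<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
    let input = Tensor::<B, 3>::zeros([2, 3, 5], device); // [batch, channels, length]
    let alpha = Tensor::<B, 1>::from_floats([0.25, 0.1, 0.0], device); // one alpha per channel
    activation::prelu(input, alpha)
}
```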
@@ -32,6 +32,7 @@ mod tests {

    // test activation
    burn_tensor::testgen_gelu!();
    burn_tensor::testgen_prelu!();
    burn_tensor::testgen_relu!();
    burn_tensor::testgen_softmax!();
    burn_tensor::testgen_sigmoid!();
@@ -24,6 +24,7 @@ mod linear;
mod norm;
mod padding;
mod pos_encoding;
mod prelu;
mod relu;
mod rnn;
mod unfold;
@@ -36,6 +37,7 @@ pub use linear::*;
pub use norm::*;
pub use padding::*;
pub use pos_encoding::*;
pub use prelu::*;
pub use relu::*;
pub use rnn::*;
pub use unfold::*;
@@ -0,0 +1,47 @@
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
/// Parametric ReLU layer.
#[derive(Module, Debug)]
pub struct PRelu<B: Backend> {
    /// The learnable weights of the PReLU layer, of shape \[1\] or \[num_parameters\];
    /// in the latter case, `num_parameters` must equal the number of channels in the input tensor.
    pub alpha: Param<Tensor<B, 1>>,
}
/// Configuration to create a [Parametric ReLU](PRelu) layer.
#[derive(Config, Debug)]
pub struct PReluConfig {
    /// The number of parameters.
    #[config(default = "1")]
    pub num_parameters: usize,
    /// The initial value of the learnable weight alpha. Default is 0.25.
    #[config(default = "0.25")]
    pub alpha: f64,
}
impl PReluConfig {
    /// Initialize a new [Parametric ReLU](PRelu) layer.
    pub fn init<B: Backend>(&self, device: &B::Device) -> PRelu<B> {
        PRelu {
            // alpha is a tensor of length `num_parameters`
            alpha: Param::from(
                Initializer::Constant { value: self.alpha }.init([self.num_parameters], device),
            ),
        }
    }
}

impl<B: Backend> PRelu<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any]`
    /// - output: `[..., any]`
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        crate::tensor::activation::prelu(input, self.alpha.val())
    }
}
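
A usage sketch for the new module, assuming Burn's usual `Config` derive conventions (a generated `new()` constructor plus `with_*` setters for the defaulted fields); the backend alias and shapes below are placeholders, not part of this commit:

```rust
use burn::nn::{PRelu, PReluConfig};
use burn::tensor::{backend::Backend, Tensor};

fn build_and_run<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
    // One alpha per channel of a 64-channel feature map, each initialized to 0.1.
    let prelu: PRelu<B> = PReluConfig::new()
        .with_num_parameters(64)
        .with_alpha(0.1)
        .init(device);
    // Input shaped [batch, channels, height, width].
    let input = Tensor::<B, 4>::zeros([8, 64, 16, 16], device);
    prelu.forward(input)
}
```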
@@ -13,6 +13,35 @@ pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::from_primitive(B::gelu(tensor.primitive))
}

/// Applies the Parametric ReLU activation:
/// `PReLU(x) = max(0, x) + \alpha * min(0, x)`
/// The tensor is assumed to be of shape \[batch_size, channels, ...\] and
/// alpha is assumed to be of shape \[channels\] or \[1\].
pub fn prelu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: Tensor<B, 1>,
) -> Tensor<B, D> {
    check!(TensorCheck::check_prelu_shape::<D>(
        &tensor.shape(),
        &alpha.shape()
    ));

    let weight = if alpha.dims()[0] == 1 {
        // If there is only one weight, reshape it to (1, 1, ..., 1) (D times) so that its rank is D.
        alpha.reshape([1; D])
    } else {
        // D >= 2 here, since the case D == 1 with more than one weight is rejected by the check above:
        // there is more than one weight and the rank is at least 2.
        let num_weights = alpha.dims()[0];
        let mut s = [1; D];
        s[1] = num_weights;
        // Reshape the weights to (1, channels, 1, ...).
        alpha.reshape(s)
    };

    Tensor::from_primitive(B::prelu(tensor.primitive, weight.primitive))
}

/// Applies the softmax function on the input tensor along the given dimension.
///
/// `softmax(x_i) = exp(x_i) / sum_j(exp(x_j))`
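
To make the reshape above concrete, here is a small standalone sketch (a hypothetical helper, not part of this commit) of the broadcast shape that `alpha` is given before the backend call:

```rust
/// Shape that a weight vector of length `num_weights` is viewed as before
/// broadcasting against an input of rank `D` shaped [batch, channels, ...].
fn prelu_weight_broadcast_shape<const D: usize>(num_weights: usize) -> [usize; D] {
    let mut shape = [1; D];
    if num_weights > 1 {
        // More than one weight: it must line up with the channel axis (dim 1).
        // The shape check guarantees D >= 2 in this case.
        shape[1] = num_weights;
    }
    shape
}

fn main() {
    // Per-channel weights for a [batch, channels, height, width] input:
    assert_eq!(prelu_weight_broadcast_shape::<4>(3), [1, 3, 1, 1]);
    // A single shared weight broadcasts over every element:
    assert_eq!(prelu_weight_broadcast_shape::<4>(1), [1, 1, 1, 1]);
}
```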
@@ -694,6 +694,44 @@ impl TensorCheck {

        check
    }
    pub(crate) fn check_prelu_shape<const D: usize>(
        shape_tensor: &Shape<D>,
        shape_weight: &Shape<1>,
    ) -> Self {
        let mut check = Self::Ok;
        if shape_weight.dims[0] == 1 {
            check
        } else if D >= 2 {
            let channels = shape_tensor.dims[1];
            let num_weights = shape_weight.dims[0];
            if channels != num_weights {
                check = check.register(
                    "PReLu",
                    TensorError::new(
                        "Number of channels in input tensor and number of weights must be equal",
                    )
                    .details(format!(
                        "Got no. of channels: {}, no. of weights: {}",
                        channels, num_weights
                    )),
                );
                return check;
            }
            check
        } else {
            check = check.register(
                "PReLu",
                TensorError::new(
                    "Number of channels in input tensor and number of weights must be equal",
                )
                .details(format!(
                    "Got no. of channels: {}, no. of weights: {}",
                    1, shape_weight.dims[0]
                )),
            );
            check
        }
    }

    /// Checks aggregate dimension such as mean and sum.
    pub(crate) fn aggregate_dim<const D: usize>(ops: &str, dim: usize) -> Self {
@@ -58,6 +58,18 @@ pub trait ActivationOps<B: Backend> {

        B::float_div_scalar(x, 2i32.elem())
    }
    /// Applies the PReLU activation function.
    /// # Arguments
    /// * `tensor` - The input tensor
    /// * `alpha` - The weight tensor
    fn prelu<const D: usize>(
        tensor: FloatTensor<B, D>,
        alpha: FloatTensor<B, D>,
    ) -> FloatTensor<B, D> {
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul(tensor.clone(), alpha);
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the Gelu activation function backward.
    ///
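
The default implementation keeps the raw value where the input is non-negative and the alpha-scaled value where it is negative. A standalone sketch of the same element-wise rule on plain slices (illustrative only, not backend code):

```rust
/// Element-wise PReLU with a pre-broadcast alpha: where x < 0 take alpha * x,
/// otherwise keep x -- the selection that `float_mask_where` performs above.
fn prelu_elementwise(x: &[f32], alpha: &[f32]) -> Vec<f32> {
    x.iter()
        .zip(alpha)
        .map(|(&x, &a)| if x < 0.0 { a * x } else { x })
        .collect()
}

fn main() {
    let out = prelu_elementwise(&[-1.0, 0.0, 2.0], &[0.25, 0.25, 0.25]);
    assert_eq!(out, vec![-0.25, 0.0, 2.0]);
}
```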
@@ -1,5 +1,6 @@
pub(crate) mod gelu;
pub(crate) mod mish;
pub(crate) mod prelu;
pub(crate) mod relu;
pub(crate) mod sigmoid;
pub(crate) mod silu;
@@ -0,0 +1,86 @@
#[burn_tensor_testgen::testgen(prelu)]
mod tests {
    use super::*;
    use burn_tensor::{activation, Data, Tensor};

    #[test]
    fn test_prelu_2_dimension() {
        let data = [
            [-1.1, 0.0, 1.2, 0.25, -5.4],
            [-4.567, 0.56, -1.55, 99.9, 0.0],
        ];
        let tensor = TestTensor::from(data);
        let data_actual =
            activation::prelu(tensor, TestTensor::from([0.5, 0.25, 0.0, -0.8, -0.4])).into_data();
        let data_expected = Data::from([
            [-0.5500, 0.0000, 1.2000, 0.2500, 2.1600],
            [-2.2835, 0.5600, -0.0000, 99.9000, -0.0000],
        ]);
        data_expected.assert_approx_eq(&data_actual, 9);
    }
    #[test]
    fn test_prelu_2_dimension_scalar_weight() {
        let data = [
            [-1.1, 0.0, 1.2, 0.25, -5.4],
            [-4.567, 0.56, -1.55, 99.9, 0.0],
        ];
        let tensor = TestTensor::from(data);
        let data_actual = activation::prelu(tensor, TestTensor::from([-0.8])).into_data();
        let data_expected = Data::from([
            [0.8800, -0.0000, 1.2000, 0.2500, 4.3200],
            [3.6536, 0.5600, 1.2400, 99.9000, -0.0000],
        ]);
        data_expected.assert_approx_eq(&data_actual, 7);
    }

    #[test]
    fn test_prelu_positives() {
        // Check that positives are untouched
        let data = [[
            0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
        ]];
        let tensor = TestTensor::from(data);
        let data_actual = activation::prelu(tensor, TestTensor::from([0.25])).into_data();
        let data_expected = Data::from(data);
        data_expected.assert_approx_eq(&data_actual, 9);
    }
    #[test]
    fn test_prelu_zero_weight() {
        // test that with weight 0 it behaves as relu
        let data = [-1.1, 0.0, 1.2, 0.25, -5.4];
        let tensor = TestTensor::from(data);
        let data_actual = activation::prelu(tensor, TestTensor::from([0.0])).into_data();
        let data_expected = Data::from([0.0, 0.0, 1.2, 0.25, 0.0]);
        data_expected.assert_approx_eq(&data_actual, 9);
    }
    #[test]
    fn test_prelu_some_weight() {
        // test that with some non zero weight it works like leaky relu
        let data = [-1.1, 0.0, 1.2, 0.25, -5.4];
        let tensor = TestTensor::from(data);
        let data_actual = activation::prelu(tensor, TestTensor::from([0.5])).into_data();
        let data_expected = Data::from([-0.550, 0.0, 1.20, 0.250, -2.70]);
        data_expected.assert_approx_eq(&data_actual, 9);
    }
    #[test]
    #[should_panic]
    fn test_prelu_single_dim_multi_weight() {
        // should panic because the data has only 1 channel
        let data = [-1.1, 2.0, 1.2, 0.25, -5.4];
        let tensor = TestTensor::from(data);
        let data_actual =
            activation::prelu(tensor, TestTensor::from([0.5, -0.25, 0.0, 0.5, -1.0])).into_data();
        let data_expected = Data::from([-0.550, 0.0, 1.20, 0.250, -2.70]);
        data_expected.assert_approx_eq(&data_actual, 9);
    }
    #[test]
    #[should_panic]
    fn test_prelu_multi_dim_wrong_weights() {
        let data = [
            [-1.1, 0.0, 1.2, 0.25, -5.4],
            [-4.567, 0.56, -1.55, 99.9, 0.0],
        ];
        let tensor = TestTensor::from(data);
        let data_actual = activation::prelu(tensor, TestTensor::from([-0.8, 0.1])).into_data();
    }
}