mirror of https://github.com/tracel-ai/burn.git
Add LeakyReLu implementation (#1208)
* Implement LeakyReLu * Cargo fmt * Apply suggestions * cargo fmt * Use float_mul_scalar * Should be grad * Add to books module * Move test files * Update leaky relu to use activation function * Update tensor.md * Fix failing test due to approx * Add back the function comment * Fix comment per PR feedback --------- Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
This commit is contained in:
parent
626457e1c6
commit
c21d5a3207
|
@ -52,7 +52,7 @@ the `Module` derive, you need to be careful to achieve the behavior you want.
|
|||
These methods are available for all modules.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|-----------------------------------------|------------------------------------------|
|
||||
| --------------------------------------- | ---------------------------------------- |
|
||||
| `module.devices()` | N/A |
|
||||
| `module.fork(device)` | Similar to `module.to(device).detach()` |
|
||||
| `module.to_device(device)` | `module.to(device)` |
|
||||
|
@ -69,7 +69,7 @@ Similar to the backend trait, there is also the `AutodiffModule` trait to signif
|
|||
autodiff support.
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|------------------|--------------------|
|
||||
| ---------------- | ------------------ |
|
||||
| `module.valid()` | `module.eval()` |
|
||||
|
||||
## Visitor & Mapper
|
||||
|
@ -107,25 +107,26 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### General
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|----------------|-----------------------------------------------|
|
||||
| -------------- | --------------------------------------------- |
|
||||
| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. |
|
||||
| `LayerNorm` | `nn.LayerNorm` |
|
||||
| `Dropout` | `nn.Dropout` |
|
||||
| `Embedding` | `nn.Embedding` |
|
||||
| `Gelu` | `nn.Gelu` |
|
||||
| `GroupNorm` | `nn.GroupNorm` |
|
||||
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
|
||||
| `RmsNorm` | _No direct equivalent_ |
|
||||
| `Dropout` | `nn.Dropout` |
|
||||
| `Gelu` | `nn.Gelu` |
|
||||
| `Prelu` | `nn.PReLu` |
|
||||
| `LayerNorm` | `nn.LayerNorm` |
|
||||
| `LeakyRelu` | `nn.LeakyReLU` |
|
||||
| `LeakyRelu` | `nn.LeakyReLu` |
|
||||
| `Linear` | `nn.Linear` |
|
||||
| `Embedding` | `nn.Embedding` |
|
||||
| `Prelu` | `nn.PReLu` |
|
||||
| `Relu` | `nn.ReLU` |
|
||||
| `RmsNorm` | _No direct equivalent_ |
|
||||
| `SwiGlu` | _No direct equivalent_ |
|
||||
|
||||
### Convolutions
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|-------------------|----------------------|
|
||||
| ----------------- | -------------------- |
|
||||
| `Conv1d` | `nn.Conv1d` |
|
||||
| `Conv2d` | `nn.Conv2d` |
|
||||
| `ConvTranspose1d` | `nn.ConvTranspose1d` |
|
||||
|
@ -134,7 +135,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### Pooling
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|---------------------|------------------------|
|
||||
| ------------------- | ---------------------- |
|
||||
| `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` |
|
||||
| `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` |
|
||||
| `AvgPool1d` | `nn.AvgPool1d` |
|
||||
|
@ -145,7 +146,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### RNNs
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|------------------|------------------------|
|
||||
| ---------------- | ---------------------- |
|
||||
| `Gru` | `nn.GRU` |
|
||||
| `Lstm` | `nn.LSTM` |
|
||||
| `GateController` | _No direct equivalent_ |
|
||||
|
@ -153,7 +154,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### Transformer
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|----------------------|-------------------------|
|
||||
| -------------------- | ----------------------- |
|
||||
| `MultiHeadAttention` | `nn.MultiheadAttention` |
|
||||
| `TransformerDecoder` | `nn.TransformerDecoder` |
|
||||
| `TransformerEncoder` | `nn.TransformerEncoder` |
|
||||
|
@ -162,7 +163,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
### Loss
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
|--------------------|-----------------------|
|
||||
| ------------------ | --------------------- |
|
||||
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
|
||||
| `MseLoss` | `nn.MSELoss` |
|
||||
| `HuberLoss` | `nn.HuberLoss` |
|
||||
|
|
|
@ -303,17 +303,18 @@ Those operations are only available for `Bool` tensors.
|
|||
|
||||
## Activation Functions
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
| ---------------------------------------- | ------------------------------------------ |
|
||||
| `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` |
|
||||
| `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` |
|
||||
| `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` |
|
||||
| `activation::mish(tensor)` | `nn.functional.mish(tensor)` |
|
||||
| `activation::prelu(tensor,alpha)` | `nn.functional.prelu(tensor,weight)` |
|
||||
| `activation::quiet_softmax(tensor, dim)` | `nn.functional.quiet_softmax(tensor, dim)` |
|
||||
| `activation::relu(tensor)` | `nn.functional.relu(tensor)` |
|
||||
| `activation::sigmoid(tensor)` | `nn.functional.sigmoid(tensor)` |
|
||||
| `activation::silu(tensor)` | `nn.functional.silu(tensor)` |
|
||||
| `activation::softmax(tensor, dim)` | `nn.functional.softmax(tensor, dim)` |
|
||||
| `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` |
|
||||
| `activation::tanh(tensor)` | `nn.functional.tanh(tensor)` |
|
||||
| Burn API | PyTorch Equivalent |
|
||||
| ------------------------------------------------ | -------------------------------------------------- |
|
||||
| `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` |
|
||||
| `activation::leaky_relu(tensor, negative_slope)` | `nn.functional.leaky_relu(tensor, negative_slope)` |
|
||||
| `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` |
|
||||
| `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` |
|
||||
| `activation::mish(tensor)` | `nn.functional.mish(tensor)` |
|
||||
| `activation::prelu(tensor,alpha)` | `nn.functional.prelu(tensor,weight)` |
|
||||
| `activation::quiet_softmax(tensor, dim)` | `nn.functional.quiet_softmax(tensor, dim)` |
|
||||
| `activation::relu(tensor)` | `nn.functional.relu(tensor)` |
|
||||
| `activation::sigmoid(tensor)` | `nn.functional.sigmoid(tensor)` |
|
||||
| `activation::silu(tensor)` | `nn.functional.silu(tensor)` |
|
||||
| `activation::softmax(tensor, dim)` | `nn.functional.softmax(tensor, dim)` |
|
||||
| `activation::softplus(tensor, beta)` | `nn.functional.softplus(tensor, beta)` |
|
||||
| `activation::tanh(tensor)` | `nn.functional.tanh(tensor)` |
|
||||
|
|
|
@ -1,43 +1,47 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use crate as burn;
|
||||
use crate::config::Config;
|
||||
use crate::module::Module;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Data;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
/// Leaky ReLu layer.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct LeakyRelu<B: Backend> {
|
||||
/// The weight used in Leaky ReLu
|
||||
pub negative_slope: Tensor<B, 1>,
|
||||
/// The negative slope.
|
||||
pub negative_slope: f64,
|
||||
phantom: PhantomData<B>,
|
||||
}
|
||||
/// Configuration to create a [Leaky Relu](LeakyRelu) layer.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct LeakyReluConfig {
|
||||
/// The negative slope. Default is 0.01
|
||||
#[config(default = "0.01")]
|
||||
pub negative_slope: f32,
|
||||
pub negative_slope: f64,
|
||||
}
|
||||
impl LeakyReluConfig {
|
||||
/// Initialize a new [Leaky Relu](LeakyRelu) Layer
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> LeakyRelu<B> {
|
||||
pub fn init<B: Backend>(&self) -> LeakyRelu<B> {
|
||||
LeakyRelu {
|
||||
negative_slope: Tensor::from_data(Data::from([self.negative_slope]).convert(), device),
|
||||
negative_slope: self.negative_slope,
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> LeakyRelu<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
/// Forward pass for the Leaky ReLu layer.
|
||||
///
|
||||
/// # Shapes
|
||||
/// # Arguments
|
||||
///
|
||||
/// - input: `[..., any]`
|
||||
/// - output: `[..., any]`
|
||||
/// * `input` - The input tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output tensor.
|
||||
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
// leaky relu is a special case of prelu where the weights are all the same. and the
|
||||
// negative_slope is not learnable
|
||||
crate::tensor::activation::prelu(input, self.negative_slope.clone())
|
||||
crate::tensor::activation::leaky_relu(input, self.negative_slope)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,7 +54,7 @@ mod tests {
|
|||
#[test]
|
||||
fn test_leaky_relu_forward() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let model: LeakyRelu<TestBackend> = LeakyReluConfig::new().init(&device);
|
||||
let model: LeakyRelu<TestBackend> = LeakyReluConfig::new().init();
|
||||
let input = Tensor::<TestBackend, 2>::from_data(Data::from([[0.4410, -0.2507]]), &device);
|
||||
let out = model.forward(input);
|
||||
assert_eq!(out.to_data(), Data::from([[0.4410, -0.002507]]));
|
||||
|
@ -83,7 +87,7 @@ mod tests {
|
|||
];
|
||||
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
let model: LeakyRelu<TestBackend> = LeakyReluConfig::new().init(&device);
|
||||
let model: LeakyRelu<TestBackend> = LeakyReluConfig::new().init();
|
||||
let input_data = Tensor::<TestBackend, 3>::from_data(Data::from(input), &device);
|
||||
let actual_output = model.forward(input_data);
|
||||
actual_output
|
||||
|
|
|
@ -88,7 +88,7 @@ represent the corresponding Burn Op.
|
|||
| [IsInf][80] | ❌ | ❌ |
|
||||
| [IsNaN][81] | ❌ | ❌ |
|
||||
| [LayerNormalization][82] | ❌ | ✅ |
|
||||
| [LeakyRelu][83] | ❌ | ❌ |
|
||||
| [LeakyRelu][83] | ✅ | ✅ |
|
||||
| [Less][84] | ❌ | ✅ |
|
||||
| [LessOrEqual][85] | ❌ | ✅ |
|
||||
| Linear | ✅ | ✅ |
|
||||
|
|
|
@ -18,7 +18,7 @@ Here is the directory structure of this crate:
|
|||
|
||||
## Setting up your python environment
|
||||
|
||||
You need to install `onnx==1.15.0` and `torch-2.1.1` in your python environment to add a new test
|
||||
You need to install `onnx==1.15.0` and `torch==2.1.1` in your python environment to add a new test
|
||||
|
||||
## Adding new tests
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ fn main() {
|
|||
.input("tests/neg/neg.onnx")
|
||||
.input("tests/recip/recip.onnx")
|
||||
.input("tests/relu/relu.onnx")
|
||||
.input("tests/leaky_relu/leaky_relu.onnx")
|
||||
.input("tests/reshape/reshape.onnx")
|
||||
.input("tests/sigmoid/sigmoid.onnx")
|
||||
.input("tests/softmax/softmax.onnx")
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
pytorch2.1.1:t
|
||||
8
|
||||
input1/relu1/LeakyRelu" LeakyRelu*
|
||||
alpha
|
||||
×#<
|
||||
main_graphZ
|
||||
input
|
||||
|
||||
|
||||
b
|
||||
1
|
||||
|
||||
|
||||
B
|
|
@ -0,0 +1,49 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: leaky_relu.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
# Use default negative_slope of 0.01
|
||||
self.relu1 = nn.LeakyReLU(negative_slope=0.01)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu1(x)
|
||||
return x
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
torch.set_printoptions(precision=8)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
file_name = "leaky_relu.onnx"
|
||||
test_input = torch.randn(2, 3, device=device)
|
||||
torch.onnx.export(model, test_input, file_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(file_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
print("Test input data of ones: {}".format(test_input))
|
||||
print("Test input data shape of ones: {}".format(test_input.shape))
|
||||
output = model.forward(test_input)
|
||||
print("Test output data shape: {}".format(output.shape))
|
||||
|
||||
print("Test output: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -33,6 +33,7 @@ include_models!(
|
|||
gather,
|
||||
gelu,
|
||||
global_avr_pool,
|
||||
leaky_relu,
|
||||
linear,
|
||||
log_softmax,
|
||||
log,
|
||||
|
@ -485,6 +486,29 @@ mod tests {
|
|||
assert!(expected_sum.approx_eq(output_sum, (1.0e-8, 2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn leaky_relu() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let device = Default::default();
|
||||
let model: leaky_relu::Model<Backend> = leaky_relu::Model::new(&device);
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 2>::from_floats(
|
||||
[
|
||||
[0.33669037, 0.0, 0.23446237],
|
||||
[0.23033303, -1.122_856, -0.18632829],
|
||||
],
|
||||
&device,
|
||||
);
|
||||
let output = model.forward(input);
|
||||
let expected = Data::from([
|
||||
[0.33669037, 0.0, 0.23446237],
|
||||
[0.23033303, -0.01122_856, -0.0018632829],
|
||||
]);
|
||||
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn relu() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
|
|
|
@ -30,6 +30,7 @@ pub enum UnaryNodeKind {
|
|||
LogSoftmax,
|
||||
Neg,
|
||||
Reciprocal,
|
||||
LeakyRelu,
|
||||
Relu,
|
||||
Sigmoid,
|
||||
Softmax,
|
||||
|
@ -51,6 +52,7 @@ impl UnaryNodeKind {
|
|||
Self::LogSoftmax => "log_softmax",
|
||||
Self::Neg => "neg",
|
||||
Self::Reciprocal => "reciprocal",
|
||||
Self::LeakyRelu => "leaky_relu",
|
||||
Self::Relu => "relu",
|
||||
Self::Sigmoid => "sigmoid",
|
||||
Self::Softmax => "softmax",
|
||||
|
@ -138,6 +140,12 @@ impl UnaryNode {
|
|||
Self::new(input, output, UnaryNodeKind::Relu, Rc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn leaky_relu(input: Type, output: Type, alpha: f64) -> Self {
|
||||
let alpha = alpha.to_tokens();
|
||||
let function = move |input| quote! { burn::tensor::activation::leaky_relu(#input, #alpha) };
|
||||
Self::new(input, output, UnaryNodeKind::Relu, Rc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn sigmoid(input: Type, output: Type) -> Self {
|
||||
let function = move |input| quote! { burn::tensor::activation::sigmoid(#input) };
|
||||
Self::new(input, output, UnaryNodeKind::Sigmoid, Rc::new(function))
|
||||
|
@ -305,6 +313,26 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_codegen_leaky_relu() {
|
||||
one_node_graph(
|
||||
UnaryNode::leaky_relu(
|
||||
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 4)),
|
||||
0.1,
|
||||
),
|
||||
quote! {
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor2 = burn::tensor::activation::leaky_relu(tensor1, 0.1);
|
||||
|
||||
tensor2
|
||||
}
|
||||
},
|
||||
vec!["tensor1".to_string()],
|
||||
vec!["tensor2".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_codegen_sigmoid() {
|
||||
one_node_graph(
|
||||
|
|
|
@ -51,6 +51,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
|
|||
NodeType::Transpose => same_as_input(node),
|
||||
NodeType::Unsqueeze => unsqueeze_update_output(node),
|
||||
NodeType::Pow => same_as_input(node),
|
||||
NodeType::LeakyRelu => same_as_input(node),
|
||||
// Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
|
||||
_ => temporary_pass_through_stub(node),
|
||||
}
|
||||
|
|
|
@ -503,6 +503,20 @@ fn padding_config(pads: &[i64]) -> PaddingConfig2d {
|
|||
}
|
||||
}
|
||||
|
||||
// Create a LeakyReluConfig from the alpha attribute of the node
|
||||
pub fn leaky_relu_config(node: &Node) -> f64 {
|
||||
let mut alpha = 0.01;
|
||||
|
||||
for (key, value) in node.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"alpha" => alpha = value.clone().into_f32() as f64,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
alpha
|
||||
}
|
||||
|
||||
pub fn reshape_config(node: &Node) -> Vec<i64> {
|
||||
let mut allowzero = 0;
|
||||
|
||||
|
|
|
@ -249,6 +249,7 @@ impl OnnxGraph {
|
|||
NodeType::Flatten => graph.register(Self::flatten_conversion(node)),
|
||||
NodeType::GatherElements => graph.register(Self::gather_conversion(node)),
|
||||
NodeType::Log => graph.register(Self::log_conversion(node)),
|
||||
NodeType::LeakyRelu => graph.register(Self::leaky_relu_conversion(node)),
|
||||
NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)),
|
||||
NodeType::Softmax => graph.register(Self::softmax_conversion(node)),
|
||||
NodeType::Sqrt => graph.register(Self::sqrt_conversion(node)),
|
||||
|
@ -397,6 +398,14 @@ impl OnnxGraph {
|
|||
UnaryNode::erf(input, output)
|
||||
}
|
||||
|
||||
fn leaky_relu_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
let alpha = leaky_relu_config(&node);
|
||||
|
||||
UnaryNode::leaky_relu(input, output, alpha)
|
||||
}
|
||||
|
||||
fn relu_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_type();
|
||||
|
|
|
@ -8,6 +8,19 @@ pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
|||
tensor.relu()
|
||||
}
|
||||
|
||||
/// Applies the leaky rectified linear unit function.
|
||||
///
|
||||
/// f(x) = negative_slope * x for x < 0, f(x) = x for x >= 0
|
||||
pub fn leaky_relu<const D: usize, B: Backend>(
|
||||
tensor: Tensor<B, D>,
|
||||
negative_slope: f64,
|
||||
) -> Tensor<B, D> {
|
||||
Tensor::from_primitive(B::leaky_relu(
|
||||
tensor.primitive,
|
||||
crate::ElementConversion::elem(negative_slope),
|
||||
))
|
||||
}
|
||||
|
||||
/// Applies the Gaussian Error Linear Units function as described in the paper in [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
|
||||
pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
Tensor::from_primitive(B::gelu(tensor.primitive))
|
||||
|
|
|
@ -8,6 +8,27 @@ use super::{FloatTensor, FullPrecisionBackend};
|
|||
///
|
||||
/// This trait let backend implementations override activation functions for better performance.
|
||||
pub trait ActivationOps<B: Backend> {
|
||||
/// Applies the LeakyReLU activation function.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `negative_slope` - The negative_slope value that values smaller than 0 are multiplied with.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output tensor.
|
||||
fn leaky_relu<const D: usize>(
|
||||
tensor: FloatTensor<B, D>,
|
||||
negative_slope: super::FloatElem<B>,
|
||||
) -> FloatTensor<B, D> {
|
||||
let mask = B::float_lower_elem(tensor.clone(), 0.elem());
|
||||
let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope.elem());
|
||||
|
||||
// Update the tensor where the values are `< 0` by `tensor * negative_slope`.
|
||||
B::float_mask_where(tensor, mask, scaled_tensor)
|
||||
}
|
||||
|
||||
/// Applies the ReLU activation function.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
#[burn_tensor_testgen::testgen(leaky_relu)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{activation, Data, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_leaky_relu_d2() {
|
||||
let tensor = TestTensor::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]);
|
||||
|
||||
let data_actual = activation::leaky_relu(tensor, 0.01).into_data();
|
||||
|
||||
let data_expected = Data::from([[0.0, -0.01, 2.0], [3.0, -0.04, 5.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
pub(crate) mod gelu;
|
||||
pub(crate) mod leaky_relu;
|
||||
pub(crate) mod mish;
|
||||
pub(crate) mod prelu;
|
||||
pub(crate) mod relu;
|
||||
|
|
|
@ -12,6 +12,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_gelu!();
|
||||
burn_tensor::testgen_mish!();
|
||||
burn_tensor::testgen_relu!();
|
||||
burn_tensor::testgen_leaky_relu!();
|
||||
burn_tensor::testgen_softmax!();
|
||||
burn_tensor::testgen_softplus!();
|
||||
burn_tensor::testgen_sigmoid!();
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
pytorch2.1.1:t
|
||||
8
|
||||
input1/relu1/LeakyRelu" LeakyRelu*
|
||||
alpha
|
||||
×#<
|
||||
main_graphZ
|
||||
input
|
||||
|
||||
|
||||
b
|
||||
1
|
||||
|
||||
|
||||
B
|
Loading…
Reference in New Issue