Add Hard Sigmoid activation function (#2112)

* Add Hard Sigmoid activation function

* Add ONNX import conversion for HardSigmoid

* Update supported operators list

* Update book

* Make test comparison approximate to eliminate precision issues

* Add burn-candle test

* Fix name in E2E test generator
Genna Wingert 2024-08-07 20:01:42 +02:00 committed by GitHub
parent af8c3150c9
commit a01004dd4a
18 changed files with 314 additions and 5 deletions

@ -168,6 +168,7 @@ Burn comes with built-in modules that you can use to build your own modules.
| `Embedding` | `nn.Embedding` |
| `Gelu` | `nn.Gelu` |
| `GroupNorm` | `nn.GroupNorm` |
| `HardSigmoid` | `nn.Hardsigmoid` |
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
| `LayerNorm` | `nn.LayerNorm` |
| `LeakyRelu` | `nn.LeakyReLU` |
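
For context on the new `HardSigmoid` row, a minimal sketch of how the layer added in this commit is constructed and called. The `HardSigmoid`/`HardSigmoidConfig` names come from the `hard_sigmoid.rs` file further down in this diff; the backend plumbing around them is assumed.

```rust
use burn::nn::{HardSigmoid, HardSigmoidConfig};
use burn::tensor::{backend::Backend, Tensor};

/// Builds a HardSigmoid layer with the defaults (alpha = 0.2, beta = 0.5) and applies it.
fn apply_hard_sigmoid<B: Backend>(input: Tensor<B, 2>) -> Tensor<B, 2> {
    let layer: HardSigmoid = HardSigmoidConfig::new().init();
    layer.forward(input)
}
```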

@ -319,6 +319,7 @@ strategies.
| Burn API | PyTorch Equivalent |
| ------------------------------------------------ | -------------------------------------------------- |
| `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` |
| `activation::hard_sigmoid(tensor, alpha, beta)`  | `nn.functional.hardsigmoid(tensor)`                 |
| `activation::leaky_relu(tensor, negative_slope)` | `nn.functional.leaky_relu(tensor, negative_slope)` |
| `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` |
| `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` |
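
And a corresponding sketch for the functional row above. One caveat worth noting: PyTorch's `nn.functional.hardsigmoid` fixes alpha = 1/6 and beta = 0.5, while Burn's `activation::hard_sigmoid` takes both explicitly, so matching PyTorch means passing those values (backend setup assumed).

```rust
use burn::tensor::{activation, backend::Backend, Tensor};

/// hard_sigmoid(x) = clamp(alpha * x + beta, 0, 1); alpha = 1/6, beta = 0.5
/// reproduces torch.nn.functional.hardsigmoid.
fn pytorch_style_hard_sigmoid<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    activation::hard_sigmoid(tensor, 1.0 / 6.0, 0.5)
}
```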

@ -39,6 +39,7 @@ mod tests {
    burn_tensor::testgen_relu!();
    burn_tensor::testgen_softmax!();
    burn_tensor::testgen_sigmoid!();
    burn_tensor::testgen_hard_sigmoid!();
    burn_tensor::testgen_silu!();

    // test module

@ -0,0 +1,94 @@
use burn_tensor::activation::hard_sigmoid;

use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Hard Sigmoid layer.
///
/// Should be created with [HardSigmoidConfig](HardSigmoidConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct HardSigmoid {
    /// The alpha value.
    pub alpha: f64,
    /// The beta value.
    pub beta: f64,
}

/// Configuration to create a [Hard Sigmoid](HardSigmoid) layer using the [init function](HardSigmoidConfig::init).
#[derive(Config, Debug)]
pub struct HardSigmoidConfig {
    /// The alpha value. Default is 0.2.
    #[config(default = "0.2")]
    pub alpha: f64,
    /// The beta value. Default is 0.5.
    #[config(default = "0.5")]
    pub beta: f64,
}

impl HardSigmoidConfig {
    /// Initialize a new [Hard Sigmoid](HardSigmoid) layer.
    pub fn init(&self) -> HardSigmoid {
        HardSigmoid {
            alpha: self.alpha,
            beta: self.beta,
        }
    }
}

impl ModuleDisplay for HardSigmoid {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("alpha", &self.alpha)
            .add("beta", &self.beta)
            .optional()
    }
}

impl HardSigmoid {
    /// Forward pass for the Hard Sigmoid layer.
    ///
    /// See [hard_sigmoid](crate::tensor::activation::hard_sigmoid) for more information.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any]`
    /// - output: `[..., any]`
    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        hard_sigmoid(input, self.alpha, self.beta)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;

    #[test]
    fn test_hard_sigmoid_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let model: HardSigmoid = HardSigmoidConfig::new().init();
        let input =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);
        let out = model.forward(input);
        let expected = TensorData::from([[0.5882, 0.44986]]);
        out.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn display() {
        let config = HardSigmoidConfig::new().init();
        assert_eq!(
            alloc::format!("{}", config),
            "HardSigmoid {alpha: 0.2, beta: 0.5}"
        );
    }
}

@ -22,6 +22,7 @@ pub mod interpolate;
mod dropout;
mod embedding;
mod gelu;
mod hard_sigmoid;
mod initializer;
mod leaky_relu;
mod linear;
@ -40,6 +41,7 @@ mod unfold;
pub use dropout::*;
pub use embedding::*;
pub use gelu::*;
pub use hard_sigmoid::*;
pub use initializer::*;
pub use leaky_relu::*;
pub use linear::*;

@ -81,7 +81,7 @@ represent the corresponding Burn Op.
| [HammingWindow][71] | ❌ | ❌ |
| [HannWindow][72] | ❌ | ❌ |
| [Hardmax][73] | ❌ | ❌ |
| [HardSigmoid][74] | ❌ | ❌ |
| [HardSigmoid][74] | ✅ | ✅ |
| [HardSwish][75] | ❌ | ❌ |
| [Identity][76] | ✅ | ✅ |
| [If][77] | ❌ | ✅ |

@ -38,6 +38,7 @@ fn main() {
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/greater/greater.onnx")
.input("tests/greater_or_equal/greater_or_equal.onnx")
.input("tests/hard_sigmoid/hard_sigmoid.onnx")
.input("tests/layer_norm/layer_norm.onnx")
.input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/less/less.onnx")

@ -0,0 +1,13 @@
(binary ONNX protobuf, omitted: `hard_sigmoid.onnx` exported by pytorch2.3.1; `main_graph` contains a single HardSigmoid node with an alpha attribute)
@ -0,0 +1,46 @@
#!/usr/bin/env python3

# used to generate model: hard_sigmoid.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.hardsigmoid1 = nn.Hardsigmoid()

    def forward(self, x):
        x = self.hardsigmoid1(x)
        return x


def main():
    # Set seed for reproducibility
    torch.manual_seed(42)
    torch.set_printoptions(precision=8)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    file_name = "hard_sigmoid.onnx"
    test_input = torch.randn(2, 3, device=device)

    torch.onnx.export(model, test_input, file_name, verbose=False, opset_version=16)
    print("Finished exporting model to {}".format(file_name))

    # Output some test data for use in the test
    print("Test input data: {}".format(test_input))
    print("Test input data shape: {}".format(test_input.shape))
    output = model.forward(test_input)
    print("Test output data shape: {}".format(output.shape))
    print("Test output: {}".format(output))


if __name__ == "__main__":
    main()

@ -47,6 +47,7 @@ include_models!(
    global_avr_pool,
    greater,
    greater_or_equal,
    hard_sigmoid,
    layer_norm,
    leaky_relu,
    less,
@ -1216,6 +1217,29 @@ mod tests {
        output.to_data().assert_approx_eq(&expected, 7);
    }

    #[test]
    fn hard_sigmoid() {
        // Initialize the model without weights (because the exported file does not contain them)
        let device = Default::default();
        let model: hard_sigmoid::Model<Backend> = hard_sigmoid::Model::new(&device);

        // Run the model
        let input = Tensor::<Backend, 2>::from_floats(
            [
                [0.33669037, 0.12880941, 0.23446237],
                [0.23033303, -1.12285638, -0.18632829],
            ],
            &device,
        );
        let output = model.forward(input);
        let expected = TensorData::from([
            [0.55611509, 0.52146822, 0.53907704],
            [0.53838885, 0.31285727, 0.46894526],
        ]);

        output.to_data().assert_approx_eq(&expected, 7);
    }

    #[test]
    fn sin() {
        let device = Default::default();

@ -28,6 +28,7 @@ pub enum UnaryNodeKind {
    Flatten,
    Gelu,
    LeakyRelu,
    HardSigmoid,
    Log,
    LogSoftmax,
    Neg,
@ -59,6 +60,7 @@ impl UnaryNodeKind {
            Self::Flatten => "flatten",
            Self::Gelu => "gelu",
            Self::LeakyRelu => "leaky_relu",
            Self::HardSigmoid => "hard_sigmoid",
            Self::Log => "log",
            Self::LogSoftmax => "log_softmax",
            Self::Neg => "neg",
@ -190,6 +192,14 @@ impl UnaryNode {
        Self::new(input, output, UnaryNodeKind::Sigmoid, Rc::new(function))
    }

    pub(crate) fn hard_sigmoid(input: Type, output: Type, alpha: f64, beta: f64) -> Self {
        let alpha = alpha.to_tokens();
        let beta = beta.to_tokens();
        let function =
            move |input| quote! { burn::tensor::activation::hard_sigmoid(#input, #alpha, #beta) };

        Self::new(input, output, UnaryNodeKind::HardSigmoid, Rc::new(function))
    }

    pub(crate) fn log_softmax(input: Type, output: Type, dim: usize) -> Self {
        let dim = dim.to_tokens();
        let function = move |input| quote! { burn::tensor::activation::log_softmax(#input, #dim) };
@ -578,6 +588,27 @@ mod tests {
        );
    }

    #[test]
    fn test_unary_codegen_hard_sigmoid() {
        one_node_graph(
            UnaryNode::hard_sigmoid(
                Type::Tensor(TensorType::new_float("tensor1", 4)),
                Type::Tensor(TensorType::new_float("tensor2", 4)),
                0.2,
                0.5,
            ),
            quote! {
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = burn::tensor::activation::hard_sigmoid(tensor1, 0.2, 0.5);

                    tensor2
                }
            },
            vec!["tensor1".to_string()],
            vec!["tensor2".to_string()],
        );
    }

    #[test]
    fn test_unary_codegen_log_softmax() {
        one_node_graph(

@ -940,6 +940,22 @@ pub fn leaky_relu_config(node: &Node) -> f64 {
    alpha
}

/// Create a HardSigmoidConfig from the alpha and beta attributes of the node.
pub fn hard_sigmoid_config(node: &Node) -> (f64, f64) {
    // ONNX HardSigmoid defaults: alpha = 0.2, beta = 0.5.
    let mut alpha = 0.2;
    let mut beta = 0.5;

    for (key, value) in node.attrs.iter() {
        match key.as_str() {
            "alpha" => alpha = value.clone().into_f32() as f64,
            "beta" => beta = value.clone().into_f32() as f64,
            _ => {}
        }
    }

    (alpha, beta)
}

pub fn reshape_config(node: &Node) -> Vec<i64> {
    let mut allowzero = 0;

@ -63,10 +63,10 @@ use super::op_configuration::{
    argmax_config, avg_pool1d_config, avg_pool2d_config, batch_norm_config, clip_config,
    concat_config, conv1d_config, conv2d_config, conv3d_config, conv_transpose2d_config,
    conv_transpose3d_config, dropout_config, expand_config, flatten_config, gather_config,
    layer_norm_config, leaky_relu_config, linear_config, log_softmax_config, max_pool1d_config,
    max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config, reduce_min_config,
    reduce_prod_config, reduce_sum_config, reshape_config, resize_config, shape_config,
    slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config,
    hard_sigmoid_config, layer_norm_config, leaky_relu_config, linear_config, log_softmax_config,
    max_pool1d_config, max_pool2d_config, pad_config, reduce_max_config, reduce_mean_config,
    reduce_min_config, reduce_prod_config, reduce_sum_config, reshape_config, resize_config,
    shape_config, slice_config, softmax_config, squeeze_config, transpose_config, unsqueeze_config,
};
use onnx_ir::{
convert_constant_value,
@ -292,6 +292,7 @@ impl ParsedOnnxGraph {
                NodeType::Flatten => graph.register(Self::flatten_conversion(node)),
                NodeType::Gather => graph.register(Self::gather_conversion(node)),
                NodeType::GatherElements => graph.register(Self::gather_elements_conversion(node)),
                NodeType::HardSigmoid => graph.register(Self::hard_sigmoid_conversion(node)),
                NodeType::Log => graph.register(Self::log_conversion(node)),
                NodeType::LeakyRelu => graph.register(Self::leaky_relu_conversion(node)),
                NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)),
@ -574,6 +575,14 @@ impl ParsedOnnxGraph {
        UnaryNode::leaky_relu(input, output, alpha)
    }

    fn hard_sigmoid_conversion(node: Node) -> UnaryNode {
        let input = Type::from(node.inputs.first().unwrap());
        let output = Type::from(node.outputs.first().unwrap());
        let (alpha, beta) = hard_sigmoid_config(&node);

        UnaryNode::hard_sigmoid(input, output, alpha, beta)
    }

    fn relu_conversion(node: Node) -> UnaryNode {
        let input = Type::from(node.inputs.first().unwrap());
        let output = Type::from(node.outputs.first().unwrap());

@ -130,6 +130,19 @@ pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D>
    )))
}

/// Applies the hard sigmoid function: `hard_sigmoid(x) = clamp(alpha * x + beta, 0, 1)`.
pub fn hard_sigmoid<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: f64,
    beta: f64,
) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::hard_sigmoid(
        tensor.primitive.tensor(),
        crate::ElementConversion::elem(alpha),
        crate::ElementConversion::elem(beta),
    )))
}

/// Applies the log sigmoid function.
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::log_sigmoid(

@ -182,6 +182,36 @@ pub trait ActivationOps<B: Backend> {
        B::float_mul(value, grad)
    }

    /// Applies the hard sigmoid activation function.
    ///
    /// `hard_sigmoid(x) = clamp(alpha * x + beta, 0, 1)`
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `alpha` - The alpha value that the tensor is multiplied with.
    /// * `beta` - The beta value that is added to the tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn hard_sigmoid<const D: usize>(
        tensor: FloatTensor<B, D>,
        alpha: super::FloatElem<B>,
        beta: super::FloatElem<B>,
    ) -> FloatTensor<B, D> {
        let tensor_full = B::float_into_full_precision(tensor);
        let tensor_tmp = FullPrecisionBackend::<B>::float_clamp(
            FullPrecisionBackend::<B>::float_add_scalar(
                FullPrecisionBackend::<B>::float_mul_scalar(tensor_full, alpha.elem()),
                beta.elem(),
            ),
            0.0.elem(),
            1.0.elem(),
        );

        B::float_from_full_precision(tensor_tmp)
    }

    /// Applies the LogSigmoid activation function.
    ///
    /// # Arguments

@ -0,0 +1,25 @@
#[burn_tensor_testgen::testgen(hard_sigmoid)]
mod tests {
    use super::*;
    use burn_tensor::{activation, Tensor, TensorData};

    #[test]
    fn test_hard_sigmoid() {
        let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);

        let output = activation::hard_sigmoid(tensor, 0.2, 0.5);
        let expected = TensorData::from([[0.7, 1.0], [1.0, 0.0]]);

        output.into_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn test_hard_sigmoid_overflow() {
        let tensor = TestTensor::<1>::from([f32::MAX, f32::MIN]);

        let output = activation::hard_sigmoid(tensor, 0.2, 0.5);
        let expected = TensorData::from([1.0, 0.0]);

        output.into_data().assert_approx_eq(&expected, 4);
    }
}

@ -1,4 +1,5 @@
pub(crate) mod gelu;
pub(crate) mod hard_sigmoid;
pub(crate) mod leaky_relu;
pub(crate) mod log_sigmoid;
pub(crate) mod mish;

@ -34,6 +34,7 @@ pub fn dim_inference(node: &mut Node) {
        NodeType::Gelu => same_as_input(node),
        NodeType::Gather => gather_update_outputs(node),
        NodeType::GatherElements => same_as_input(node),
        NodeType::HardSigmoid => same_as_input(node),
        NodeType::GlobalAveragePool => same_as_input(node),
        NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
        NodeType::LayerNormalization => same_as_input(node),