mirror of https://github.com/tracel-ai/burn.git
Implement Quiet Softmax (`Attention Is Off By One`) (#692)
* Added quiet_softmax * Undid bad formatting --------- Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
This commit is contained in:
parent
f0c75aa748
commit
03af140e12
|
@ -46,4 +46,27 @@ mod tests {
|
|||
.to_data()
|
||||
.assert_approx_eq(&Data::from([[30.5984, -47.2267], [55.9631, -56.5914]]), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quiet_softmax_grad() {
|
||||
let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]);
|
||||
|
||||
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1).require_grad();
|
||||
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone());
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
grad_1
|
||||
.to_data()
|
||||
.assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3);
|
||||
grad_2
|
||||
.to_data()
|
||||
.assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,14 @@ pub struct MultiHeadAttentionConfig {
|
|||
/// A value too low might result in NaN.
|
||||
#[config(default = -1.0e4)]
|
||||
min_float: f64,
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
///
|
||||
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
|
||||
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
|
||||
///
|
||||
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
|
||||
#[config(default = false)]
|
||||
quiet_softmax: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
|
||||
|
@ -51,6 +59,7 @@ pub struct MultiHeadAttention<B: Backend> {
|
|||
n_heads: usize,
|
||||
d_k: usize,
|
||||
min_float: f64,
|
||||
quiet_softmax: bool,
|
||||
}
|
||||
|
||||
/// [Multihead attention](MultiHeadAttention) forward pass input argument.
|
||||
|
@ -82,6 +91,7 @@ impl MultiHeadAttentionConfig {
|
|||
n_heads: self.n_heads,
|
||||
d_k: self.d_model / self.n_heads,
|
||||
min_float: self.min_float,
|
||||
quiet_softmax: self.quiet_softmax,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -105,6 +115,7 @@ impl MultiHeadAttentionConfig {
|
|||
n_heads: self.n_heads,
|
||||
d_k: self.d_model / self.n_heads,
|
||||
min_float: self.min_float,
|
||||
quiet_softmax: self.quiet_softmax,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -249,7 +260,11 @@ impl<B: Backend> MultiHeadAttention<B> {
|
|||
);
|
||||
}
|
||||
|
||||
activation::softmax(attn_scores, 3)
|
||||
if self.quiet_softmax {
|
||||
activation::quiet_softmax(attn_scores, 3)
|
||||
} else {
|
||||
activation::softmax(attn_scores, 3)
|
||||
}
|
||||
}
|
||||
|
||||
fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
|
||||
|
|
|
@ -45,10 +45,10 @@ impl BinaryCrossEntropyLossConfig {
|
|||
fn assertions(&self) {
|
||||
if let Some(alpha) = self.smoothing {
|
||||
assert!(
|
||||
(0.0..=1.).contains(&alpha),
|
||||
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
|
||||
alpha
|
||||
);
|
||||
(0.0..=1.).contains(&alpha),
|
||||
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
|
||||
alpha
|
||||
);
|
||||
};
|
||||
if let Some(weights) = self.weights.as_ref() {
|
||||
assert!(
|
||||
|
|
|
@ -53,10 +53,10 @@ impl CrossEntropyLossConfig {
|
|||
fn assertions(&self) {
|
||||
if let Some(alpha) = self.smoothing {
|
||||
assert!(
|
||||
(0.0..=1.).contains(&alpha),
|
||||
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
|
||||
alpha
|
||||
);
|
||||
(0.0..=1.).contains(&alpha),
|
||||
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
|
||||
alpha
|
||||
);
|
||||
};
|
||||
if let Some(weights) = self.weights.as_ref() {
|
||||
assert!(
|
||||
|
|
|
@ -34,6 +34,14 @@ pub struct TransformerDecoderConfig {
|
|||
/// Layer norm will be applied first instead of after the other modules.
|
||||
#[config(default = false)]
|
||||
pub norm_first: bool,
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
///
|
||||
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
|
||||
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
|
||||
///
|
||||
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
|
||||
#[config(default = false)]
|
||||
pub quiet_softmax: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
|
||||
|
@ -186,11 +194,13 @@ impl<B: Backend> TransformerDecoderLayer<B> {
|
|||
let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_quiet_softmax(config.quiet_softmax)
|
||||
.init();
|
||||
|
||||
let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_quiet_softmax(config.quiet_softmax)
|
||||
.init();
|
||||
let norm_1 = LayerNormConfig::new(config.d_model).init();
|
||||
let norm_2 = LayerNormConfig::new(config.d_model).init();
|
||||
|
@ -219,10 +229,12 @@ impl<B: Backend> TransformerDecoderLayer<B> {
|
|||
let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_quiet_softmax(config.quiet_softmax)
|
||||
.init_with(record.self_attn);
|
||||
let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_quiet_softmax(config.quiet_softmax)
|
||||
.init_with(record.cross_attn);
|
||||
let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1);
|
||||
let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2);
|
||||
|
|
|
@ -34,6 +34,14 @@ pub struct TransformerEncoderConfig {
|
|||
/// Layer norm will be applied first instead of after the other modules.
|
||||
#[config(default = false)]
|
||||
pub norm_first: bool,
|
||||
/// Use "quiet softmax" instead of regular softmax.
|
||||
///
|
||||
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
|
||||
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
|
||||
///
|
||||
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
|
||||
#[config(default = false)]
|
||||
pub quiet_softmax: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
|
||||
|
@ -175,6 +183,7 @@ impl<B: Backend> TransformerEncoderLayer<B> {
|
|||
let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_quiet_softmax(config.quiet_softmax)
|
||||
.init_with(record.mha);
|
||||
let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1);
|
||||
let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2);
|
||||
|
@ -197,6 +206,7 @@ impl<B: Backend> TransformerEncoderLayer<B> {
|
|||
let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
|
||||
.with_initializer(config.initializer.clone())
|
||||
.with_dropout(config.dropout)
|
||||
.with_quiet_softmax(config.quiet_softmax)
|
||||
.init();
|
||||
let norm_1 = LayerNormConfig::new(config.d_model).init();
|
||||
let norm_2 = LayerNormConfig::new(config.d_model).init();
|
||||
|
|
|
@ -72,11 +72,11 @@ impl ConfigEnumAnalyzer {
|
|||
fn gen_serialize_fn(&self) -> TokenStream {
|
||||
let enum_name = self.serde_enum_ident();
|
||||
let variants = self.data.variants.iter().map(|variant| {
|
||||
let variant_name = &variant.ident;
|
||||
let (variant_input, variant_output) = self.gen_variant_field(variant);
|
||||
let variant_name = &variant.ident;
|
||||
let (variant_input, variant_output) = self.gen_variant_field(variant);
|
||||
|
||||
quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output }
|
||||
});
|
||||
quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output }
|
||||
});
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
|
@ -97,11 +97,11 @@ impl ConfigEnumAnalyzer {
|
|||
fn gen_deserialize_fn(&self) -> TokenStream {
|
||||
let enum_name = self.serde_enum_ident();
|
||||
let variants = self.data.variants.iter().map(|variant| {
|
||||
let variant_name = &variant.ident;
|
||||
let (variant_input, variant_output) = self.gen_variant_field(variant);
|
||||
let variant_name = &variant.ident;
|
||||
let (variant_input, variant_output) = self.gen_variant_field(variant);
|
||||
|
||||
quote! { #enum_name::#variant_name #variant_input => Self::#variant_name #variant_output }
|
||||
});
|
||||
quote! { #enum_name::#variant_name #variant_input => Self::#variant_name #variant_output }
|
||||
});
|
||||
let name = &self.name;
|
||||
|
||||
quote! {
|
||||
|
|
|
@ -23,9 +23,9 @@ impl RecordItemCodegen for StructRecordItemCodegen {
|
|||
/// Field to be serialized.
|
||||
pub #name: <#ty as burn::record::Record>::Item<S>,
|
||||
});
|
||||
bounds.extend(quote!{
|
||||
<#ty as burn::record::Record>::Item<S>: serde::Serialize + serde::de::DeserializeOwned,
|
||||
});
|
||||
bounds.extend(quote! {
|
||||
<#ty as burn::record::Record>::Item<S>: serde::Serialize + serde::de::DeserializeOwned,
|
||||
});
|
||||
}
|
||||
let bound = bounds.to_string();
|
||||
|
||||
|
|
|
@ -167,25 +167,25 @@ mod tests {
|
|||
use crate::burn::{ScalarKind, ScalarType, TensorType};
|
||||
|
||||
macro_rules! test_binary_operator_on_tensors {
|
||||
($operator:ident) => {{
|
||||
one_node_graph(
|
||||
BinaryNode::$operator(
|
||||
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor3", 4)),
|
||||
),
|
||||
quote! {
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor3 = tensor1.$operator(tensor2);
|
||||
($operator:ident) => {{
|
||||
one_node_graph(
|
||||
BinaryNode::$operator(
|
||||
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor3", 4)),
|
||||
),
|
||||
quote! {
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor3 = tensor1.$operator(tensor2);
|
||||
|
||||
tensor3
|
||||
}
|
||||
},
|
||||
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
vec!["tensor3".to_string()],
|
||||
);
|
||||
}};
|
||||
}
|
||||
tensor3
|
||||
}
|
||||
},
|
||||
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
vec!["tensor3".to_string()],
|
||||
);
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! test_binary_operator_on_tensor_and_scalar {
|
||||
($operator:ident, $burn_operator:ident) => {{
|
||||
|
|
|
@ -354,7 +354,10 @@ where
|
|||
|
||||
for i in 0..D - 1 {
|
||||
if shape_tensor.dims[i] != shape_indices.dims[i] {
|
||||
panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indices.dims);
|
||||
panic!(
|
||||
"Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}",
|
||||
shape_tensor.dims, shape_indices.dims
|
||||
);
|
||||
}
|
||||
batch_size *= shape_indices.dims[i];
|
||||
}
|
||||
|
|
|
@ -31,6 +31,26 @@ pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) ->
|
|||
tensor.div(tensor_tmp)
|
||||
}
|
||||
|
||||
/// Applies the "quiet softmax" function on the input tensor along the given dimension.
|
||||
/// This function is similar to the softmax function, but it allows for "no selection", e.g.,
|
||||
/// all outputs can tend to zero.
|
||||
///
|
||||
/// `softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The dimension argument `dim` specifies the dimension along which the function will be computed.
|
||||
/// It must in the range of `0` and `D-1`.
|
||||
pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
|
||||
check!(TensorCheck::dim_ops::<D>("softmax", dim));
|
||||
|
||||
let tensor = tensor.clone() - tensor.detach().max_dim(dim);
|
||||
let tensor = tensor.exp();
|
||||
let tensor_tmp = tensor.clone().sum_dim(dim);
|
||||
|
||||
tensor.div(tensor_tmp + 1)
|
||||
}
|
||||
|
||||
/// Applies the log softmax function on the input tensor along the given dimension.
|
||||
///
|
||||
/// `log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`
|
||||
|
|
|
@ -128,13 +128,16 @@ impl TensorCheck {
|
|||
let mut check = Self::Ok;
|
||||
|
||||
if original.num_elements() != target.num_elements() {
|
||||
check = check.register("Reshape", TensorError::new(
|
||||
"The given shape doesn't have the same number of elements as the current tensor.",
|
||||
)
|
||||
.details(format!(
|
||||
"Current shape: {:?}, target shape: {:?}.",
|
||||
original.dims, target.dims
|
||||
)));
|
||||
check = check.register(
|
||||
"Reshape",
|
||||
TensorError::new(
|
||||
"The given shape doesn't have the same number of elements as the current tensor.",
|
||||
)
|
||||
.details(format!(
|
||||
"Current shape: {:?}, target shape: {:?}.",
|
||||
original.dims, target.dims
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
check
|
||||
|
@ -307,8 +310,8 @@ impl TensorCheck {
|
|||
check = check.register(
|
||||
"Matmul",
|
||||
TensorError::new(format!(
|
||||
"The inner dimension of matmul should be the same, but got {dim_lhs} and {dim_rhs}."
|
||||
))
|
||||
"The inner dimension of matmul should be the same, but got {dim_lhs} and {dim_rhs}."
|
||||
))
|
||||
.details(format!(
|
||||
"Lhs shape {:?}, rhs shape {:?}.",
|
||||
shape_lhs.dims, shape_rhs.dims
|
||||
|
@ -399,15 +402,16 @@ impl TensorCheck {
|
|||
|
||||
if shape_reference != shape {
|
||||
return check.register(
|
||||
"Cat",
|
||||
TensorError::new("Can't concatenate tensors with different shapes, except for the provided dimension").details(
|
||||
format!(
|
||||
"Provided dimension ({}), tensors shapes: {:?}",
|
||||
dim,
|
||||
tensors.iter().map(Tensor::shape).collect::<Vec<_>>()
|
||||
),
|
||||
),
|
||||
);
|
||||
"Cat",
|
||||
TensorError::new(
|
||||
"Can't concatenate tensors with different shapes, except for the provided dimension",
|
||||
)
|
||||
.details(format!(
|
||||
"Provided dimension ({}), tensors shapes: {:?}",
|
||||
dim,
|
||||
tensors.iter().map(Tensor::shape).collect::<Vec<_>>()
|
||||
)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -438,18 +442,16 @@ impl TensorCheck {
|
|||
|
||||
if range.end > d_tensor {
|
||||
check = check.register(
|
||||
"Slice",
|
||||
TensorError::new("The provided ranges array has a range that exceeds the current tensor size.")
|
||||
.details(format!(
|
||||
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
|
||||
"Slice",
|
||||
TensorError::new(
|
||||
"The provided ranges array has a range that exceeds the current tensor size.",
|
||||
)
|
||||
.details(format!(
|
||||
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
|
||||
Tensor shape {:?}, provided ranges {:?}.",
|
||||
range.start,
|
||||
range.end,
|
||||
d_tensor,
|
||||
i,
|
||||
shape.dims,
|
||||
ranges,
|
||||
)));
|
||||
range.start, range.end, d_tensor, i, shape.dims, ranges,
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
if range.start >= range.end {
|
||||
|
@ -479,13 +481,16 @@ impl TensorCheck {
|
|||
let mut check = Self::Ok;
|
||||
|
||||
if D1 < D2 {
|
||||
check = check.register("Slice Assign",
|
||||
TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.")
|
||||
.details(
|
||||
format!(
|
||||
"The ranges array must be smaller or equal to the tensor number of dimensions. \
|
||||
check = check.register(
|
||||
"Slice Assign",
|
||||
TensorError::new(
|
||||
"The provided ranges array has a higher number of dimensions than the current tensor.",
|
||||
)
|
||||
.details(format!(
|
||||
"The ranges array must be smaller or equal to the tensor number of dimensions. \
|
||||
Tensor number of dimensions: {D1}, ranges array length {D2}."
|
||||
)));
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
for i in 0..usize::min(D1, D2) {
|
||||
|
@ -495,19 +500,16 @@ impl TensorCheck {
|
|||
|
||||
if range.end > d_tensor {
|
||||
check = check.register(
|
||||
"Range Assign",
|
||||
TensorError::new("The provided ranges array has a range that exceeds the current tensor size.")
|
||||
.details(format!(
|
||||
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
|
||||
"Range Assign",
|
||||
TensorError::new(
|
||||
"The provided ranges array has a range that exceeds the current tensor size.",
|
||||
)
|
||||
.details(format!(
|
||||
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
|
||||
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
|
||||
range.start,
|
||||
range.end,
|
||||
d_tensor,
|
||||
i,
|
||||
shape.dims,
|
||||
shape_value.dims,
|
||||
ranges,
|
||||
)));
|
||||
range.start, range.end, d_tensor, i, shape.dims, shape_value.dims, ranges,
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
||||
if range.end - range.start != d_tensor_value {
|
||||
|
@ -697,17 +699,16 @@ impl TensorCheck {
|
|||
continue;
|
||||
}
|
||||
|
||||
check = check.register(ops,
|
||||
TensorError::new("The provided tensors have incompatible shapes.")
|
||||
.details(format!(
|
||||
"Incompatible size at dimension '{}' => '{} != {}', which can't be broadcasted. \
|
||||
check = check.register(
|
||||
ops,
|
||||
TensorError::new("The provided tensors have incompatible shapes.").details(
|
||||
format!(
|
||||
"Incompatible size at dimension '{}' => '{} != {}', which can't be broadcasted. \
|
||||
Lhs tensor shape {:?}, Rhs tensor shape {:?}.",
|
||||
i,
|
||||
d_lhs,
|
||||
d_rhs,
|
||||
lhs.dims,
|
||||
rhs.dims,
|
||||
)));
|
||||
i, d_lhs, d_rhs, lhs.dims, rhs.dims,
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -320,10 +320,9 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
|
|||
if err > tolerance {
|
||||
// Only print the first 5 different values.
|
||||
if num_diff < max_num_diff {
|
||||
message += format!(
|
||||
"\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}"
|
||||
)
|
||||
.as_str();
|
||||
message +=
|
||||
format!("\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}")
|
||||
.as_str();
|
||||
}
|
||||
num_diff += 1;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
#[burn_tensor_testgen::testgen(quiet_softmax)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{activation, Data, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_quiet_softmax_d2() {
|
||||
let data = Data::from([[1.0, 7.0], [13.0, -3.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||
|
||||
let data_actual = activation::quiet_softmax(tensor, 1).to_data();
|
||||
|
||||
let data_expected = Data::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]);
|
||||
data_actual.assert_approx_eq(&data_expected, 4);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue