Implement Quiet Softmax (`Attention Is Off By One`) (#692)

* Added quiet_softmax

* Undid bad formatting

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Will Brickner 2023-11-30 11:58:30 -06:00 committed by GitHub
parent f0c75aa748
commit 03af140e12
14 changed files with 199 additions and 100 deletions
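
For context on what the change computes: quiet softmax differs from standard softmax only by an extra `1` in the denominator, which lets every output fall toward zero when no logit is large. A minimal, dependency-free Rust sketch of the two formulas (illustration only, not the burn API):

```rust
/// Standard softmax: exp(x_i) / sum_j(exp(x_j)); outputs always sum to 1.
fn softmax(xs: &[f64]) -> Vec<f64> {
    let sum: f64 = xs.iter().map(|x| x.exp()).sum();
    xs.iter().map(|x| x.exp() / sum).collect()
}

/// Quiet softmax: exp(x_i) / (1 + sum_j(exp(x_j))); outputs can sum to
/// (nearly) zero, so an attention head can abstain from selecting anything.
fn quiet_softmax(xs: &[f64]) -> Vec<f64> {
    let sum: f64 = xs.iter().map(|x| x.exp()).sum();
    xs.iter().map(|x| x.exp() / (1.0 + sum)).collect()
}

fn main() {
    // With strongly negative logits, quiet softmax deposits almost nothing,
    // while regular softmax must still distribute a full unit of weight.
    println!("{:?}", softmax(&[-10.0, -10.0])); // [0.5, 0.5]
    println!("{:?}", quiet_softmax(&[-10.0, -10.0])); // [~4.54e-5, ~4.54e-5]
}
```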

View File

@@ -46,4 +46,27 @@ mod tests {
.to_data()
.assert_approx_eq(&Data::from([[30.5984, -47.2267], [55.9631, -56.5914]]), 3);
}
#[test]
fn test_quiet_softmax_grad() {
let data_1 = Data::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::from([[6.0, 7.0], [9.0, 10.0]]);
let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1).require_grad();
let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2).require_grad();
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
let tensor_4 = activation::quiet_softmax(tensor_3, 1).matmul(tensor_2.clone());
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
grad_1
.to_data()
.assert_approx_eq(&Data::from([[1.1797, 1.1797], [0.0055, 0.0055]]), 3);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[0.2534, 0.2862], [0.5286, 2.9317]]), 3);
}
}
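
For intuition on tests like this, the autodiff gradients can be cross-checked against a central finite difference of the ideal quiet-softmax formula. A self-contained sketch with hypothetical helper names (plain Rust, not part of the commit):

```rust
/// Quiet softmax output component `i` for a 1-D input (ideal formula).
fn quiet_softmax_at(xs: &[f64], i: usize) -> f64 {
    let sum: f64 = xs.iter().map(|x| x.exp()).sum();
    xs[i].exp() / (1.0 + sum)
}

/// Central finite-difference approximation of d f / d x[i].
fn finite_diff(f: impl Fn(&[f64]) -> f64, x: &mut [f64], i: usize, eps: f64) -> f64 {
    let orig = x[i];
    x[i] = orig + eps;
    let hi = f(&x[..]);
    x[i] = orig - eps;
    let lo = f(&x[..]);
    x[i] = orig;
    (hi - lo) / (2.0 * eps)
}

fn main() {
    // First row of tensor_3 above: data_1.matmul(data_2) = [[9, 10], [54, 61]].
    let mut x = [9.0, 10.0];
    // Approximate d quiet_softmax(x)_1 / d x_0 numerically.
    let g = finite_diff(|v| quiet_softmax_at(v, 1), &mut x, 0, 1e-6);
    println!("approx gradient: {g:.6}");
}
```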

View File

@@ -25,6 +25,14 @@ pub struct MultiHeadAttentionConfig {
/// A value too low might result in NaN.
#[config(default = -1.0e4)]
min_float: f64,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
quiet_softmax: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
@@ -51,6 +59,7 @@ pub struct MultiHeadAttention<B: Backend> {
n_heads: usize,
d_k: usize,
min_float: f64,
quiet_softmax: bool,
}
/// [Multihead attention](MultiHeadAttention) forward pass input argument.
@@ -82,6 +91,7 @@ impl MultiHeadAttentionConfig {
n_heads: self.n_heads,
d_k: self.d_model / self.n_heads,
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
}
}
@@ -105,6 +115,7 @@ impl MultiHeadAttentionConfig {
n_heads: self.n_heads,
d_k: self.d_model / self.n_heads,
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
}
}
}
@@ -249,7 +260,11 @@ impl<B: Backend> MultiHeadAttention<B> {
);
}
if self.quiet_softmax {
activation::quiet_softmax(attn_scores, 3)
} else {
activation::softmax(attn_scores, 3)
}
}
fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
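
Downstream, enabling the flag goes through the `with_quiet_softmax` setter generated by the `#[config]` attribute on the new field. A minimal construction sketch (assuming this era's `init()` signature without a device argument):

```rust
use burn::nn::attention::{MultiHeadAttention, MultiHeadAttentionConfig};
use burn::tensor::backend::Backend;

/// Builds a 4-head attention module whose attention weights use quiet softmax.
fn build_quiet_mha<B: Backend>() -> MultiHeadAttention<B> {
    MultiHeadAttentionConfig::new(64, 4) // d_model = 64, n_heads = 4
        .with_quiet_softmax(true) // setter generated from the new config field
        .init()
}
```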

View File

@@ -45,10 +45,10 @@ impl BinaryCrossEntropyLossConfig {
fn assertions(&self) {
if let Some(alpha) = self.smoothing {
assert!(
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
alpha
);
};
if let Some(weights) = self.weights.as_ref() {
assert!(

View File

@@ -53,10 +53,10 @@ impl CrossEntropyLossConfig {
fn assertions(&self) {
if let Some(alpha) = self.smoothing {
assert!(
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
alpha
);
};
if let Some(weights) = self.weights.as_ref() {
assert!(

View File

@@ -34,6 +34,14 @@ pub struct TransformerDecoderConfig {
/// Layer norm will be applied first instead of after the other modules.
#[config(default = false)]
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
pub quiet_softmax: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
@@ -186,11 +194,13 @@ impl<B: Backend> TransformerDecoderLayer<B> {
let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init();
let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init();
let norm_1 = LayerNormConfig::new(config.d_model).init();
let norm_2 = LayerNormConfig::new(config.d_model).init();
@@ -219,10 +229,12 @@ impl<B: Backend> TransformerDecoderLayer<B> {
let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init_with(record.self_attn);
let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init_with(record.cross_attn);
let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1);
let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2);

View File

@@ -34,6 +34,14 @@ pub struct TransformerEncoderConfig {
/// Layer norm will be applied first instead of after the other modules.
#[config(default = false)]
pub norm_first: bool,
/// Use "quiet softmax" instead of regular softmax.
///
/// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
/// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
///
/// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
#[config(default = false)]
pub quiet_softmax: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/libm::sqrt(3.0), fan_out_only:false}"
@@ -175,6 +183,7 @@ impl<B: Backend> TransformerEncoderLayer<B> {
let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init_with(record.mha);
let norm_1 = LayerNormConfig::new(config.d_model).init_with(record.norm_1);
let norm_2 = LayerNormConfig::new(config.d_model).init_with(record.norm_2);
@@ -197,6 +206,7 @@ impl<B: Backend> TransformerEncoderLayer<B> {
let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init();
let norm_1 = LayerNormConfig::new(config.d_model).init();
let norm_2 = LayerNormConfig::new(config.d_model).init();
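
Both the encoder and decoder configs merely forward the flag into every attention layer, so opting in at the top level is one setter call. A sketch mirroring the attention example above (argument order `d_model, d_ff, n_heads, n_layers` assumed from the config's field order):

```rust
use burn::nn::transformer::{TransformerEncoder, TransformerEncoderConfig};
use burn::tensor::backend::Backend;

/// A 6-layer encoder in which every self-attention block uses quiet softmax.
fn build_quiet_encoder<B: Backend>() -> TransformerEncoder<B> {
    TransformerEncoderConfig::new(256, 1024, 8, 6)
        .with_quiet_softmax(true)
        .init()
}
```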

View File

@@ -72,11 +72,11 @@ impl ConfigEnumAnalyzer {
fn gen_serialize_fn(&self) -> TokenStream {
let enum_name = self.serde_enum_ident();
let variants = self.data.variants.iter().map(|variant| {
let variant_name = &variant.ident;
let (variant_input, variant_output) = self.gen_variant_field(variant);
quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output }
});
let name = &self.name;
quote! {
@@ -97,11 +97,11 @@ impl ConfigEnumAnalyzer {
fn gen_deserialize_fn(&self) -> TokenStream {
let enum_name = self.serde_enum_ident();
let variants = self.data.variants.iter().map(|variant| {
let variant_name = &variant.ident;
let (variant_input, variant_output) = self.gen_variant_field(variant);
quote! { #enum_name::#variant_name #variant_input => Self::#variant_name #variant_output }
});
let name = &self.name;
quote! {

View File

@@ -23,9 +23,9 @@ impl RecordItemCodegen for StructRecordItemCodegen {
/// Field to be serialized.
pub #name: <#ty as burn::record::Record>::Item<S>,
});
bounds.extend(quote! {
<#ty as burn::record::Record>::Item<S>: serde::Serialize + serde::de::DeserializeOwned,
});
}
let bound = bounds.to_string();

View File

@@ -167,25 +167,25 @@ mod tests {
use crate::burn::{ScalarKind, ScalarType, TensorType};
macro_rules! test_binary_operator_on_tensors {
($operator:ident) => {{
one_node_graph(
BinaryNode::$operator(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
Type::Tensor(TensorType::new_float("tensor3", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor3 = tensor1.$operator(tensor2);
tensor3
}
},
vec!["tensor1".to_string(), "tensor2".to_string()],
vec!["tensor3".to_string()],
);
}};
}
macro_rules! test_binary_operator_on_tensor_and_scalar {
($operator:ident, $burn_operator:ident) => {{

View File

@@ -354,7 +354,10 @@ where
for i in 0..D - 1 {
if shape_tensor.dims[i] != shape_indices.dims[i] {
panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indices.dims);
panic!(
"Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}",
shape_tensor.dims, shape_indices.dims
);
}
batch_size *= shape_indices.dims[i];
}

View File

@@ -31,6 +31,26 @@ pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) ->
tensor.div(tensor_tmp)
}
/// Applies the "quiet softmax" function on the input tensor along the given dimension.
/// This function is similar to the softmax function, but it allows for "no selection", e.g.,
/// all outputs can tend to zero.
///
/// `softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`
///
/// # Notes
///
/// The dimension argument `dim` specifies the dimension along which the function will be computed.
/// It must in the range of `0` and `D-1`.
pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
check!(TensorCheck::dim_ops::<D>("quiet_softmax", dim));
let tensor = tensor.clone() - tensor.detach().max_dim(dim);
let tensor = tensor.exp();
let tensor_tmp = tensor.clone().sum_dim(dim);
tensor.div(tensor_tmp + 1)
}
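
One detail worth tracing through: the implementation subtracts the per-dimension maximum `m` before exponentiating (the usual stability shift) and only then adds `1` to the denominator, so the quantity actually computed is

`exp(x_i - m) / [ 1 + sum_j(exp(x_j - m)) ] = exp(x_i) / [ exp(m) + sum_j(exp(x_j)) ]`

which coincides with the documented `exp(x_i) / [ 1 + sum_j(exp(x_j)) ]` exactly when `m = 0` (a derivation from the lines above, not text from the commit).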
/// Applies the log softmax function on the input tensor along the given dimension.
///
/// `log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`

View File

@@ -128,13 +128,16 @@ impl TensorCheck {
let mut check = Self::Ok;
if original.num_elements() != target.num_elements() {
check = check.register("Reshape", TensorError::new(
"The given shape doesn't have the same number of elements as the current tensor.",
)
.details(format!(
"Current shape: {:?}, target shape: {:?}.",
original.dims, target.dims
)));
check = check.register(
"Reshape",
TensorError::new(
"The given shape doesn't have the same number of elements as the current tensor.",
)
.details(format!(
"Current shape: {:?}, target shape: {:?}.",
original.dims, target.dims
)),
);
}
check
@@ -307,8 +310,8 @@ impl TensorCheck {
check = check.register(
"Matmul",
TensorError::new(format!(
"The inner dimension of matmul should be the same, but got {dim_lhs} and {dim_rhs}."
))
"The inner dimension of matmul should be the same, but got {dim_lhs} and {dim_rhs}."
))
.details(format!(
"Lhs shape {:?}, rhs shape {:?}.",
shape_lhs.dims, shape_rhs.dims
@@ -399,15 +402,16 @@ impl TensorCheck {
if shape_reference != shape {
return check.register(
"Cat",
TensorError::new("Can't concatenate tensors with different shapes, except for the provided dimension").details(
format!(
"Provided dimension ({}), tensors shapes: {:?}",
dim,
tensors.iter().map(Tensor::shape).collect::<Vec<_>>()
),
),
);
"Cat",
TensorError::new(
"Can't concatenate tensors with different shapes, except for the provided dimension",
)
.details(format!(
"Provided dimension ({}), tensors shapes: {:?}",
dim,
tensors.iter().map(Tensor::shape).collect::<Vec<_>>()
)),
);
}
}
@@ -438,18 +442,16 @@ impl TensorCheck {
if range.end > d_tensor {
check = check.register(
"Slice",
TensorError::new("The provided ranges array has a range that exceeds the current tensor size.")
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
"Slice",
TensorError::new(
"The provided ranges array has a range that exceeds the current tensor size.",
)
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
Tensor shape {:?}, provided ranges {:?}.",
range.start,
range.end,
d_tensor,
i,
shape.dims,
ranges,
)));
range.start, range.end, d_tensor, i, shape.dims, ranges,
)),
);
}
if range.start >= range.end {
@@ -479,13 +481,16 @@ impl TensorCheck {
let mut check = Self::Ok;
if D1 < D2 {
check = check.register("Slice Assign",
TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.")
.details(
format!(
"The ranges array must be smaller or equal to the tensor number of dimensions. \
check = check.register(
"Slice Assign",
TensorError::new(
"The provided ranges array has a higher number of dimensions than the current tensor.",
)
.details(format!(
"The ranges array must be smaller or equal to the tensor number of dimensions. \
Tensor number of dimensions: {D1}, ranges array length {D2}."
)));
)),
);
}
for i in 0..usize::min(D1, D2) {
@@ -495,19 +500,16 @@ impl TensorCheck {
if range.end > d_tensor {
check = check.register(
"Range Assign",
TensorError::new("The provided ranges array has a range that exceeds the current tensor size.")
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
"Range Assign",
TensorError::new(
"The provided ranges array has a range that exceeds the current tensor size.",
)
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
range.start,
range.end,
d_tensor,
i,
shape.dims,
shape_value.dims,
ranges,
)));
range.start, range.end, d_tensor, i, shape.dims, shape_value.dims, ranges,
)),
);
}
if range.end - range.start != d_tensor_value {
@@ -697,17 +699,16 @@ impl TensorCheck {
continue;
}
check = check.register(
ops,
TensorError::new("The provided tensors have incompatible shapes.").details(
format!(
"Incompatible size at dimension '{}' => '{} != {}', which can't be broadcasted. \
Lhs tensor shape {:?}, Rhs tensor shape {:?}.",
i, d_lhs, d_rhs, lhs.dims, rhs.dims,
),
),
);
}
}

View File

@@ -320,10 +320,9 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
if err > tolerance {
// Only print the first 5 different values.
if num_diff < max_num_diff {
message +=
format!("\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}")
.as_str();
}
num_diff += 1;
}

View File

@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(quiet_softmax)]
mod tests {
use super::*;
use burn_tensor::{activation, Data, Tensor};
#[test]
fn test_quiet_softmax_d2() {
let data = Data::from([[1.0, 7.0], [13.0, -3.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = activation::quiet_softmax(tensor, 1).to_data();
let data_expected = Data::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]);
data_actual.assert_approx_eq(&data_expected, 4);
}
}
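
As a spot check against the ideal formula: for the second row `[13.0, -3.0]`, `exp(13) ≈ 4.42413e5` and `exp(-3) ≈ 4.9787e-2`, so the denominator `1 + sum_j(exp(x_j))` is about `4.42414e5`, giving outputs `[≈1.0, ≈1.1254e-07]`, in agreement with the expected row above.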