diff --git a/burn-ndarray/src/ops/tensor.rs b/burn-ndarray/src/ops/tensor.rs
index 592b94ab2..1c475500a 100644
--- a/burn-ndarray/src/ops/tensor.rs
+++ b/burn-ndarray/src/ops/tensor.rs
@@ -6,7 +6,7 @@
 use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
 use crate::{to_nd_array_tensor, NdArrayDevice, SEED};
 use burn_tensor::Distribution;
 use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
-use ndarray::{Axis, Dim, IxDyn, SliceInfoElem};
+use ndarray::{ArcArray, Axis, Dim, Dimension, IxDyn, SliceInfoElem};
 use rand::rngs::StdRng;
 use rand::SeedableRng;
@@ -527,14 +527,12 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
     }
 
     fn cat<const D: usize>(tensors: &[NdArrayTensor<E, D>], dim: usize) -> NdArrayTensor<E, D> {
-        let mut shape = tensors.get(0).unwrap().shape;
-        shape.dims[dim] = tensors.len();
-
         let arrays: Vec<ndarray::ArrayView<E, IxDyn>> =
             tensors.iter().map(|t| t.array.view()).collect();
         let array = ndarray::concatenate(Axis(dim), &arrays)
             .unwrap()
             .into_shared();
 
+        let shape = array_shape(&array);
         NdArrayTensor { array, shape }
     }
 
@@ -646,3 +644,13 @@ fn cmp_min(a: &f64, b: &f64) -> Ordering {
     }
     Ordering::Equal
 }
+
+fn array_shape<E, const D: usize>(array: &ArcArray<E, IxDyn>) -> Shape<D> {
+    let dims = array
+        .raw_dim()
+        .slice()
+        .iter()
+        .map(|a| *a as i64)
+        .collect::<Vec<i64>>();
+    Shape::from(dims)
+}
diff --git a/burn-tensor/src/tensor/base.rs b/burn-tensor/src/tensor/base.rs
index a01e51584..707b86a80 100644
--- a/burn-tensor/src/tensor/base.rs
+++ b/burn-tensor/src/tensor/base.rs
@@ -399,6 +399,12 @@ where
         Self::new(tensor)
     }
 
+    /// Create a tensor of the given shape, on the given device, where each element is one.
+    pub fn ones_device<S: Into<Shape<D>>>(shape: S, device: B::Device) -> Self {
+        let tensor = B::ones(shape.into(), device);
+        Self::new(tensor)
+    }
+
     /// Returns a tensor containing the elements selected from the given ranges.
     ///
     /// # Panics
diff --git a/burn/src/nn/attention/mask.rs b/burn/src/nn/attention/mask.rs
new file mode 100644
index 000000000..d67777909
--- /dev/null
+++ b/burn/src/nn/attention/mask.rs
@@ -0,0 +1,50 @@
+use burn_tensor::{backend::Backend, BoolTensor, ElementConversion, Tensor};
+
+/// Generate an autoregressive attention mask.
+///
+/// The mask can be used in Transformer modules to train models to generate tensors sequentially.
+pub fn generate_autoregressive_mask<B: Backend>(
+    batch_size: usize,
+    seq_length: usize,
+    device: B::Device,
+) -> BoolTensor<B, 3> {
+    let mut mask = Tensor::<B::IntegerBackend, 3>::zeros([1, seq_length, seq_length]);
+
+    for i in 0..seq_length {
+        let values = Tensor::<B::IntegerBackend, 3>::ones([1, 1, seq_length - (i + 1)]);
+        mask = mask.index_assign([0..1, i..i + 1, i + 1..seq_length], &values);
+    }
+
+    mask = mask.to_device(device).repeat(0, batch_size);
+
+    BoolTensor::from_int_backend(mask.equal_scalar(1_i64.to_elem::<i64>()))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::TestBackend;
+    use burn_tensor::Data;
+
+    #[test]
+    fn test_generate_autoregressive_mask() {
+        let device = <TestBackend as Backend>::Device::default();
+        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, device);
+
+        assert_eq!(
+            mask.into_data(),
+            Data::from([
+                [
+                    [false, true, true],
+                    [false, false, true],
+                    [false, false, false],
+                ],
+                [
+                    [false, true, true],
+                    [false, false, true],
+                    [false, false, false],
+                ]
+            ])
+        );
+    }
+}
diff --git a/burn/src/nn/attention/mha.rs b/burn/src/nn/attention/mha.rs
index d6a17c4ad..0b44349c9 100644
--- a/burn/src/nn/attention/mha.rs
+++ b/burn/src/nn/attention/mha.rs
@@ -1,5 +1,6 @@
 use crate as burn;
 
+use crate::nn::cache::TensorCache;
 use crate::{
     config::Config,
     module::{Module, Param},
@@ -46,7 +47,7 @@ pub struct MultiHeadAttention<B: Backend> {
 }
 
 /// [Multihead attention](MultiHeadAttention) forward pass input argument.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct MhaInput<B: Backend> {
     query: Tensor<B, 3>,
     key: Tensor<B, 3>,
@@ -93,7 +94,7 @@ impl<B: Backend> MhaInput<B> {
 }
 
 /// [Multihead attention](MultiHeadAttention) outputs.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct MhaOutput<B: Backend> {
     /// The attention weights [batch_size, seq_length_1, seq_length_2].
     pub weights: Tensor<B, 4>,
@@ -151,6 +152,51 @@ impl<B: Backend> MultiHeadAttention<B> {
         MhaOutput { weights, context }
     }
 
+    /// Applies the forward pass on the input tensors using an autoregressive cache.
+    ///
+    /// # Shapes
+    ///
+    /// - query: `[batch_size, seq_length_1, d_model]`
+    /// - key: `[batch_size, seq_length_2, d_model]`
+    /// - value: `[batch_size, seq_length_2, d_model]`
+    /// - output: `[batch_size, seq_length_1, d_model]`
+    pub fn forward_autoregressive_inference(
+        &self,
+        input: MhaInput<B>,
+        cache: &mut MHAAutoregressiveCache<B>,
+    ) -> MhaOutput<B> {
+        let [batch_size, seq_length_1, d_model] = input.query.dims();
+
+        let attention_linear = |cache: &mut TensorCache<B, 4>,
+                                tensor: Tensor<B, 3>,
+                                param: &Param<nn::Linear<B>>| {
+            cache.forward_autoregressive(tensor, 2, |tensor| self.attention_linear(tensor, param))
+        };
+
+        let query = attention_linear(&mut cache.query, input.query, &self.query);
+        let key = attention_linear(&mut cache.key, input.key, &self.key);
+        let value = attention_linear(&mut cache.value, input.value, &self.value);
+
+        let attn_scores = self.attn_scores(query, key);
+        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
+
+        let context = weights.matmul(&value);
+        let context = context
+            .swap_dims(1, 2)
+            .reshape([batch_size, seq_length_1, d_model]);
+
+        let context = cache
+            .output
+            .forward_autoregressive(context, 1, |context| self.output.forward(context));
+
+        MhaOutput { weights, context }
+    }
+
+    /// Create an empty autoregressive cache.
+    pub fn new_autoregressive_cache(&self) -> MHAAutoregressiveCache<B> {
+        MHAAutoregressiveCache::default()
+    }
+
     fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
         let attn_scores = query
             .matmul(&key.transpose())
@@ -195,10 +241,21 @@ impl<B: Backend> MultiHeadAttention<B> {
     }
 }
 
+/// Autoregressive cache for the [Multi Head Attention](MultiHeadAttention) layer.
+///
+/// To be used during inference when decoding tokens.
+#[derive(Default)]
+pub struct MHAAutoregressiveCache<B: Backend> {
+    query: TensorCache<B, 4>,
+    key: TensorCache<B, 4>,
+    value: TensorCache<B, 4>,
+    output: TensorCache<B, 3>,
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::TestBackend;
+    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
     use burn::tensor::{Distribution, Shape};
 
     #[test]
@@ -253,7 +310,7 @@ mod tests {
     }
 
     #[test]
-    pub fn test_self_attention_mask_pad() {
+    fn test_self_attention_mask_pad() {
         let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2];
         let mha = MultiHeadAttention::new(&MultiHeadAttentionConfig::new(d_model, n_heads));
 
@@ -298,4 +355,38 @@ mod tests {
             3,
         );
     }
+
+    #[test]
+    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
+        let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2];
+        let mha = MultiHeadAttention::new(&MultiHeadAttentionConfig::new(d_model, n_heads));
+
+        let tensor = Tensor::<TestBackend, 3>::random(
+            [batch_size, seq_length, d_model],
+            Distribution::Standard,
+        );
+        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, tensor.device());
+        let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);
+
+        let output_1 = mha.forward(input);
+        let mut output_2 = Vec::new();
+        let mut cache = mha.new_autoregressive_cache();
+
+        for i in 1..seq_length + 1 {
+            let tensor = tensor.index([0..batch_size, 0..i, 0..d_model]);
+            let input = MhaInput::self_attn(tensor);
+            let next_tok = mha
+                .forward_autoregressive_inference(input, &mut cache)
+                .context
+                .index([0..batch_size, i - 1..i, 0..d_model]);
+            output_2.push(next_tok);
+        }
+
+        let output_2 = Tensor::cat(output_2, 1);
+
+        output_1
+            .context
+            .into_data()
+            .assert_approx_eq(&output_2.into_data(), 3);
+    }
 }
diff --git a/burn/src/nn/attention/mod.rs b/burn/src/nn/attention/mod.rs
index e92d1621c..783379f48 100644
--- a/burn/src/nn/attention/mod.rs
+++ b/burn/src/nn/attention/mod.rs
@@ -1,3 +1,5 @@
+mod mask;
 mod mha;
 
+pub use mask::*;
 pub use mha::*;
diff --git a/burn/src/nn/cache/autoregressive.rs b/burn/src/nn/cache/autoregressive.rs
new file mode 100644
index 000000000..3ed26bca2
--- /dev/null
+++ b/burn/src/nn/cache/autoregressive.rs
@@ -0,0 +1,33 @@
+use super::TensorCache;
+use crate::tensor::backend::Backend;
+use crate::tensor::Tensor;
+
+impl<B: Backend, const D: usize> TensorCache<B, D> {
+    pub(crate) fn forward_autoregressive<F>(
+        &mut self,
+        tensor: Tensor<B, 3>,
+        dim_cat: usize,
+        func: F,
+    ) -> Tensor<B, D>
+    where
+        F: Fn(Tensor<B, 3>) -> Tensor<B, D>,
+    {
+        let mut tensor_old = None;
+        std::mem::swap(&mut self.state, &mut tensor_old);
+
+        let tensor_new = match tensor_old {
+            Some(tensor_old) => {
+                let [batch_size, seq_length, d_model] = tensor.dims();
+                let next_seq_token =
+                    tensor.index([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]);
+                let next_seq_token = func(next_seq_token);
+
+                Tensor::cat(vec![tensor_old, next_seq_token], dim_cat)
+            }
+            None => func(tensor),
+        };
+
+        self.state = Some(tensor_new.clone());
+        tensor_new
+    }
+}
diff --git a/burn/src/nn/cache/base.rs b/burn/src/nn/cache/base.rs
new file mode 100644
index 000000000..73726ea27
--- /dev/null
+++ b/burn/src/nn/cache/base.rs
@@ -0,0 +1,13 @@
+use crate::tensor::backend::Backend;
+use crate::tensor::Tensor;
+
+#[derive(Default)]
+pub struct TensorCache<B: Backend, const D: usize> {
+    pub(crate) state: Option<Tensor<B, D>>,
+}
+
+impl<B: Backend, const D: usize> TensorCache<B, D> {
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
diff --git a/burn/src/nn/cache/mod.rs b/burn/src/nn/cache/mod.rs
new file mode 100644
index 000000000..39f532180
--- /dev/null
+++ b/burn/src/nn/cache/mod.rs
@@ -0,0 +1,5 @@
+mod autoregressive;
+mod base;
+
+pub use autoregressive::*;
+pub use base::*;
diff --git a/burn/src/nn/mod.rs b/burn/src/nn/mod.rs
index fd15e5a7a..c5c64787e 100644
--- a/burn/src/nn/mod.rs
+++ b/burn/src/nn/mod.rs
@@ -1,4 +1,5 @@
 pub mod attention;
+pub mod cache;
 pub mod transformer;
 
 mod dropout;
diff --git a/burn/src/nn/transformer/encoder.rs b/burn/src/nn/transformer/encoder.rs
index c10e40133..9fe493720 100644
--- a/burn/src/nn/transformer/encoder.rs
+++ b/burn/src/nn/transformer/encoder.rs
@@ -1,4 +1,7 @@
-use crate as burn;
+use crate::{
+    self as burn,
+    nn::{attention::MHAAutoregressiveCache, cache::TensorCache},
+};
 
 use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
 use crate::{
@@ -96,6 +99,38 @@ impl<B: Backend> TransformerEncoder<B> {
 
         x
     }
+
+    /// Applies the forward pass on the input tensor using an autoregressive cache.
+    ///
+    /// # Shapes
+    ///
+    /// - tensor: `[batch_size, seq_length, d_model]`
+    /// - output: `[batch_size, seq_length, d_model]`
+    pub fn forward_autoregressive_inference(
+        &self,
+        input: TransformerEncoderInput<B>,
+        cache: &mut TransformerEncoderAutoregressiveCache<B>,
+    ) -> Tensor<B, 3> {
+        let mut x = input.tensor;
+
+        for i in 0..self.layers.len() {
+            let layer = self.layers.get(i).unwrap();
+            let cache = cache.layers.get_mut(i).unwrap();
+
+            x = layer.forward_autoregressive_inference(
+                x,
+                input.mask_pad.clone(),
+                input.mask_attn.clone(),
+                cache,
+            );
+        }
+
+        x
+    }
+
+    /// Create an empty autoregressive cache.
+    pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
+        TransformerEncoderAutoregressiveCache::empty(self.layers.len())
+    }
 }
 
 #[derive(Module, Debug)]
@@ -156,4 +191,111 @@ impl<B: Backend> TransformerEncoderLayer<B> {
 
         self.norm_2.forward(x_2)
     }
+
+    fn forward_autoregressive_inference(
+        &self,
+        input: Tensor<B, 3>,
+        mask_pad: Option<BoolTensor<B, 2>>,
+        mask_attn: Option<BoolTensor<B, 3>>,
+        cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
+    ) -> Tensor<B, 3> {
+        let mut input_mhs = MhaInput::self_attn(input.clone());
+
+        if let Some(mask_pad) = mask_pad {
+            input_mhs = input_mhs.mask_pad(mask_pad);
+        }
+
+        if let Some(mask_attn) = mask_attn {
+            input_mhs = input_mhs.mask_attn(mask_attn);
+        }
+
+        let x_1 = self
+            .mha
+            .forward_autoregressive_inference(input_mhs, &mut cache.mha);
+        let x_1 = self.dropout.forward(x_1.context) + input;
+        let x_1 = cache
+            .norm_1
+            .forward_autoregressive(x_1, 1, |x_1| self.norm_1.forward(x_1));
+
+        let x_2 = cache
+            .pwff
+            .forward_autoregressive(x_1.clone(), 1, |x_1| self.pwff.forward(x_1));
+        let x_2 = self.dropout.forward(x_2) + x_1;
+
+        cache
+            .norm_2
+            .forward_autoregressive(x_2, 1, |x_2| self.norm_2.forward(x_2))
+    }
+}
+
+#[derive(Default)]
+struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
+    mha: MHAAutoregressiveCache<B>,
+    pwff: TensorCache<B, 3>,
+    norm_1: TensorCache<B, 3>,
+    norm_2: TensorCache<B, 3>,
+}
+
+impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
+    fn new() -> Self {
+        Self::default()
+    }
+}
+
+/// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer.
+///
+/// To be used during inference when decoding tokens.
+pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
+    layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
+}
+
+impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
+    fn empty(num_layers: usize) -> Self {
+        Self {
+            layers: (0..num_layers)
+                .map(|_| TransformerEncoderLayerAutoregressiveCache::new())
+                .collect(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
+    use burn_tensor::Distribution;
+
+    #[test]
+    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
+        let [batch_size, seq_length, d_model, d_ff, n_heads, num_layers] = [3, 4, 12, 24, 2, 3];
+        let transformer = TransformerEncoder::new(&TransformerEncoderConfig::new(
+            d_model, d_ff, n_heads, num_layers,
+        ));
+
+        let tensor = Tensor::<TestBackend, 3>::random(
+            [batch_size, seq_length, d_model],
+            Distribution::Standard,
+        );
+        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, tensor.device());
+        let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);
+
+        let output_1 = transformer.forward(input);
+        let mut output_2 = Vec::new();
+        let mut cache = transformer.new_autoregressive_cache();
+
+        for i in 1..seq_length + 1 {
+            let tensor = tensor.index([0..batch_size, 0..i, 0..d_model]);
+            let input = TransformerEncoderInput::new(tensor.clone());
+            let next_tok = transformer
+                .forward_autoregressive_inference(input, &mut cache)
+                .index([0..batch_size, i - 1..i, 0..d_model]);
+            output_2.push(next_tok);
+        }
+
+        let output_2 = Tensor::cat(output_2, 1);
+
+        output_1
+            .into_data()
+            .assert_approx_eq(&output_2.into_data(), 3);
+    }
 }
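How the pieces introduced by this patch fit together at training time: `generate_autoregressive_mask` builds the boolean mask that `MhaInput::mask_attn` consumes, exactly as the new attention tests do. The sketch below is illustrative only and not part of the diff; the `burn_ndarray::NdArrayBackend<f32>` backend alias and the dimensions are assumptions chosen for the example.

```rust
use burn::nn::attention::{
    generate_autoregressive_mask, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig,
};
use burn::tensor::{Distribution, Tensor};
use burn_ndarray::NdArrayBackend;

// Illustrative backend and sizes; any Backend implementation works the same way.
type B = NdArrayBackend<f32>;

fn main() {
    let (batch_size, seq_length, d_model, n_heads) = (2, 8, 16, 2);
    let mha = MultiHeadAttention::new(&MultiHeadAttentionConfig::new(d_model, n_heads));

    let embeddings =
        Tensor::<B, 3>::random([batch_size, seq_length, d_model], Distribution::Standard);

    // `true` entries mark the future positions a token is not allowed to attend to.
    let mask_attn = generate_autoregressive_mask::<B>(batch_size, seq_length, embeddings.device());
    let input = MhaInput::self_attn(embeddings).mask_attn(mask_attn);

    // context: [batch_size, seq_length, d_model]; each position only attends to its past.
    let output = mha.forward(input);
    println!("{:?}", output.context.dims());
}
```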
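At inference time the same computation can be run incrementally with the new caches. The sketch below mirrors the encoder test above; the helper name `decode_with_cache` and the idea of wrapping the loop in a function are illustrative assumptions, not part of this diff.

```rust
use burn::nn::transformer::{TransformerEncoder, TransformerEncoderInput};
use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical helper: run the encoder over a growing prefix, reusing cached projections.
fn decode_with_cache<B: Backend>(
    encoder: &TransformerEncoder<B>,
    embeddings: Tensor<B, 3>, // [batch_size, seq_length, d_model]
) -> Tensor<B, 3> {
    let [batch_size, seq_length, d_model] = embeddings.dims();

    // One cache per decoding session; every layer keeps its own TensorCache entries.
    let mut cache = encoder.new_autoregressive_cache();
    let mut outputs = Vec::new();

    for i in 1..seq_length + 1 {
        // Feed the prefix decoded so far; only the newest position is recomputed,
        // earlier projections are read back from the cache.
        let prefix = embeddings.index([0..batch_size, 0..i, 0..d_model]);
        let input = TransformerEncoderInput::new(prefix);
        let output = encoder.forward_autoregressive_inference(input, &mut cache);

        // Keep only the newly decoded position: [batch_size, 1, d_model].
        outputs.push(output.index([0..batch_size, i - 1..i, 0..d_model]));
    }

    // [batch_size, seq_length, d_model], equal (up to precision) to a full masked forward pass.
    Tensor::cat(outputs, 1)
}
```

This mirrors what the new tests assert: decoding token by token with the cache produces the same output as one forward pass over the whole sequence under the autoregressive mask, which is what makes the cache safe to use for generation.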