mirror of https://github.com/tracel-ai/burn.git
Feat/autoregressive transformer (#125)
commit 63d8d39517 (parent b99b23e1a7)
@@ -6,7 +6,7 @@ use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
 use crate::{to_nd_array_tensor, NdArrayDevice, SEED};
 use burn_tensor::Distribution;
 use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
-use ndarray::{Axis, Dim, IxDyn, SliceInfoElem};
+use ndarray::{ArcArray, Axis, Dim, Dimension, IxDyn, SliceInfoElem};
 use rand::rngs::StdRng;
 use rand::SeedableRng;
@@ -527,14 +527,12 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
     }

     fn cat<const D: usize>(tensors: &[NdArrayTensor<E, D>], dim: usize) -> NdArrayTensor<E, D> {
-        let mut shape = tensors.get(0).unwrap().shape;
-        shape.dims[dim] = tensors.len();
-
         let arrays: Vec<ndarray::ArrayView<E, IxDyn>> =
             tensors.iter().map(|t| t.array.view()).collect();
         let array = ndarray::concatenate(Axis(dim), &arrays)
             .unwrap()
             .into_shared();
+        let shape = array_shape(&array);

         NdArrayTensor { array, shape }
     }
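Note on the `cat` fix above: the removed code pre-computed the output shape by setting `shape.dims[dim] = tensors.len()`, which is only correct when every input has extent 1 along `dim`. The new code concatenates first and derives the shape from the result. A minimal standalone sketch of the distinction, using plain `ndarray` (values illustrative, not part of the diff):

```rust
use ndarray::{concatenate, Array2, Axis};

fn main() {
    // Two [2, 3] arrays concatenated along axis 0.
    let a = Array2::<f32>::zeros((2, 3));
    let b = Array2::<f32>::ones((2, 3));
    let c = concatenate(Axis(0), &[a.view(), b.view()]).unwrap();

    // The true extent along axis 0 is 2 + 2 = 4, not `tensors.len()` = 2.
    assert_eq!(c.shape(), &[4, 3]);
}
```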
@@ -646,3 +644,13 @@ fn cmp_min(a: &f64, b: &f64) -> Ordering {
     }
     Ordering::Equal
 }
+
+fn array_shape<const D: usize, E>(array: &ArcArray<E, IxDyn>) -> Shape<D> {
+    let dims = array
+        .raw_dim()
+        .slice()
+        .iter()
+        .map(|a| *a as i64)
+        .collect::<Vec<_>>();
+    Shape::from(dims)
+}
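For reference, what the new `array_shape` helper reads: `raw_dim().slice()` exposes a dynamic-dim array's extents as `&[usize]`, which is why the `Dimension` trait was added to the imports above. A hedged standalone sketch (illustrative only):

```rust
use ndarray::{ArrayD, Dimension, IxDyn};

fn main() {
    let a = ArrayD::<f32>::zeros(IxDyn(&[2, 3, 4]));
    // `slice()` comes from the `Dimension` trait.
    let dims: Vec<i64> = a.raw_dim().slice().iter().map(|&d| d as i64).collect();
    assert_eq!(dims, vec![2, 3, 4]);
}
```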
@@ -399,6 +399,12 @@ where
         Self::new(tensor)
     }

+    /// Create a tensor of the given shape where each element is one.
+    pub fn ones_device<S: Into<Shape<D>>>(shape: S, device: B::Device) -> Self {
+        let tensor = B::ones(shape.into(), device);
+        Self::new(tensor)
+    }
+
     /// Returns a tensor containing the elements selected from the given ranges.
     ///
     /// # Panics
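A brief usage sketch for the new constructor (hedged: generic over any backend `B`, names illustrative):

```rust
use burn_tensor::{backend::Backend, Tensor};

fn example<B: Backend>(device: B::Device) {
    // Allocate a [2, 3] ones tensor directly on `device`, avoiding a
    // default-device allocation followed by a transfer.
    let ones = Tensor::<B, 2>::ones_device([2, 3], device);
    assert_eq!(ones.dims(), [2, 3]);
}
```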
@@ -0,0 +1,50 @@
+use burn_tensor::{backend::Backend, BoolTensor, ElementConversion, Tensor};
+
+/// Generate an autoregressive attention mask.
+///
+/// The mask can be used in Transformer modules to train models to generate tensors sequentially.
+pub fn generate_autoregressive_mask<B: Backend>(
+    batch_size: usize,
+    seq_length: usize,
+    device: B::Device,
+) -> BoolTensor<B, 3> {
+    let mut mask = Tensor::<B::IntegerBackend, 3>::zeros([1, seq_length, seq_length]);
+
+    for i in 0..seq_length {
+        let values = Tensor::<B::IntegerBackend, 3>::ones([1, 1, seq_length - (i + 1)]);
+        mask = mask.index_assign([0..1, i..i + 1, i + 1..seq_length], &values);
+    }
+
+    mask = mask.to_device(device).repeat(0, batch_size);
+
+    BoolTensor::from_int_backend(mask.equal_scalar(1_i64.to_elem::<i64>()))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::TestBackend;
+    use burn_tensor::Data;
+
+    #[test]
+    fn test_generate_autoregressive_mask() {
+        let device = <TestBackend as Backend>::Device::default();
+        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, device);
+
+        assert_eq!(
+            mask.into_data(),
+            Data::from([
+                [
+                    [false, true, true],
+                    [false, false, true],
+                    [false, false, false],
+                ],
+                [
+                    [false, true, true],
+                    [false, false, true],
+                    [false, false, false],
+                ]
+            ])
+        );
+    }
+}
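For intuition, the mask marks position `(i, j)` as `true` (masked out) exactly when `j > i`, so each token attends only to itself and earlier positions. A backend-free sketch of the same pattern (illustrative, not part of the diff):

```rust
// Build the boolean upper-triangular pattern produced per batch entry.
fn autoregressive_mask(seq_length: usize) -> Vec<Vec<bool>> {
    (0..seq_length)
        .map(|i| (0..seq_length).map(|j| j > i).collect())
        .collect()
}

fn main() {
    // Matches the expected data in the test above for seq_length = 3.
    for row in autoregressive_mask(3) {
        println!("{:?}", row);
    }
}
```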
@@ -1,5 +1,6 @@
 use crate as burn;

+use crate::nn::cache::TensorCache;
 use crate::{
     config::Config,
     module::{Module, Param},
@@ -46,7 +47,7 @@ pub struct MultiHeadAttention<B: Backend> {
 }

 /// [Multihead attention](MultiHeadAttention) forward pass input argument.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct MhaInput<B: Backend> {
     query: Tensor<B, 3>,
     key: Tensor<B, 3>,
@@ -93,7 +94,7 @@ impl<B: Backend> MhaInput<B> {
 }

 /// [Multihead attention](MultiHeadAttention) outputs.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct MhaOutput<B: Backend> {
     /// The attention weights [batch_size, seq_length_1, seq_length_2].
     pub weights: Tensor<B, 4>,
@@ -151,6 +152,51 @@ impl<B: Backend> MultiHeadAttention<B> {
         MhaOutput { weights, context }
     }

+    /// Applies the forward pass on the input tensors using an autoregressive cache.
+    ///
+    /// # Shapes
+    ///
+    /// - query: `[batch_size, seq_length_1, d_model]`
+    /// - key: `[batch_size, seq_length_2, d_model]`
+    /// - value: `[batch_size, seq_length_2, d_model]`
+    /// - output: `[batch_size, seq_length_1, d_model]`
+    pub fn forward_autoregressive_inference(
+        &self,
+        input: MhaInput<B>,
+        cache: &mut MHAAutoregressiveCache<B>,
+    ) -> MhaOutput<B> {
+        let [batch_size, seq_length_1, d_model] = input.query.dims();
+
+        let attention_linear = |cache: &mut TensorCache<B, 4>,
+                                tensor: Tensor<B, 3>,
+                                param: &Param<nn::Linear<B>>| {
+            cache.forward_autoregressive(tensor, 2, |tensor| self.attention_linear(tensor, param))
+        };
+
+        let query = attention_linear(&mut cache.query, input.query, &self.query);
+        let key = attention_linear(&mut cache.key, input.key, &self.key);
+        let value = attention_linear(&mut cache.value, input.value, &self.value);
+
+        let attn_scores = self.attn_scores(query, key);
+        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
+
+        let context = weights.matmul(&value);
+        let context = context
+            .swap_dims(1, 2)
+            .reshape([batch_size, seq_length_1, d_model]);
+
+        let context = cache
+            .output
+            .forward_autoregressive(context, 1, |context| self.output.forward(context));
+
+        MhaOutput { weights, context }
+    }
+
+    /// Create an empty autoregressive cache.
+    pub fn new_autoregressive_cache(&self) -> MHAAutoregressiveCache<B> {
+        MHAAutoregressiveCache::default()
+    }
+
     fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
         let attn_scores = query
             .matmul(&key.transpose())
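The `dim_cat` arguments encode where the sequence axis lives: the query/key/value caches concatenate along dim 2 because `attention_linear` presumably yields head-split `[batch_size, n_heads, seq_length, d_k]` tensors, while the output cache concatenates along dim 1 of the `[batch_size, seq_length, d_model]` context. A hedged usage sketch mirroring the test further down (assumes the types from this diff are in scope):

```rust
use burn_tensor::{backend::Backend, Tensor};

// Decode one position at a time; each call only projects the newest token
// because earlier projections are replayed from the cache.
fn decode_with_cache<B: Backend>(
    mha: &MultiHeadAttention<B>,
    tokens: Tensor<B, 3>,
) -> Vec<Tensor<B, 3>> {
    let [batch_size, seq_length, d_model] = tokens.dims();
    let mut cache = mha.new_autoregressive_cache();
    let mut decoded = Vec::new();

    for i in 1..seq_length + 1 {
        let prefix = tokens.index([0..batch_size, 0..i, 0..d_model]);
        let output = mha.forward_autoregressive_inference(MhaInput::self_attn(prefix), &mut cache);
        decoded.push(output.context.index([0..batch_size, i - 1..i, 0..d_model]));
    }
    decoded
}
```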
@@ -195,10 +241,21 @@ impl<B: Backend> MultiHeadAttention<B> {
     }
 }

+/// Autoregressive cache for the [Multi Head Attention](MultiHeadAttention) layer.
+///
+/// To be used during inference when decoding tokens.
+#[derive(Default)]
+pub struct MHAAutoregressiveCache<B: Backend> {
+    query: TensorCache<B, 4>,
+    key: TensorCache<B, 4>,
+    value: TensorCache<B, 4>,
+    output: TensorCache<B, 3>,
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::TestBackend;
+    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
     use burn::tensor::{Distribution, Shape};

     #[test]
@@ -253,7 +310,7 @@ mod tests {
     }

     #[test]
-    pub fn test_self_attention_mask_pad() {
+    fn test_self_attention_mask_pad() {
         let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2];
         let mha = MultiHeadAttention::new(&MultiHeadAttentionConfig::new(d_model, n_heads));

@@ -298,4 +355,38 @@ mod tests {
             3,
         );
     }
+
+    #[test]
+    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
+        let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2];
+        let mha = MultiHeadAttention::new(&MultiHeadAttentionConfig::new(d_model, n_heads));
+
+        let tensor = Tensor::<TestBackend, 3>::random(
+            [batch_size, seq_length, d_model],
+            Distribution::Standard,
+        );
+        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, tensor.device());
+        let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);
+
+        let output_1 = mha.forward(input);
+        let mut output_2 = Vec::new();
+        let mut cache = mha.new_autoregressive_cache();
+
+        for i in 1..seq_length + 1 {
+            let tensor = tensor.index([0..batch_size, 0..i, 0..d_model]);
+            let input = MhaInput::self_attn(tensor);
+            let next_tok = mha
+                .forward_autoregressive_inference(input, &mut cache)
+                .context
+                .index([0..batch_size, i - 1..i, 0..d_model]);
+            output_2.push(next_tok);
+        }
+
+        let output_2 = Tensor::cat(output_2, 1);
+
+        output_1
+            .context
+            .into_data()
+            .assert_approx_eq(&output_2.into_data(), 3);
+    }
 }
@@ -1,3 +1,5 @@
+mod mask;
 mod mha;

+pub use mask::*;
 pub use mha::*;
@@ -0,0 +1,33 @@
+use super::TensorCache;
+use crate::tensor::backend::Backend;
+use crate::tensor::Tensor;
+
+impl<B: Backend, const D: usize> TensorCache<B, D> {
+    pub(crate) fn forward_autoregressive<F>(
+        &mut self,
+        tensor: Tensor<B, 3>,
+        dim_cat: usize,
+        func: F,
+    ) -> Tensor<B, D>
+    where
+        F: Fn(Tensor<B, 3>) -> Tensor<B, D>,
+    {
+        let mut tensor_old = None;
+        std::mem::swap(&mut self.state, &mut tensor_old);
+
+        let tensor_new = match tensor_old {
+            Some(tensor_old) => {
+                let [batch_size, seq_length, d_model] = tensor.dims();
+                let next_seq_token =
+                    tensor.index([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]);
+                let next_seq_token = func(next_seq_token);
+
+                Tensor::cat(vec![tensor_old, next_seq_token], dim_cat)
+            }
+            None => func(tensor),
+        };
+
+        self.state = Some(tensor_new.clone());
+        tensor_new
+    }
+}
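The caching trick above is independent of tensors: if a result for the first `n - 1` positions is already stored, apply `func` only to the newest position and append. This is only valid when `func` treats positions independently (or, for attention, when an autoregressive mask guarantees old outputs cannot change). A toy stand-in over `Vec<f32>` (illustrative, not part of the diff):

```rust
// Toy analogue of TensorCache::forward_autoregressive, operating on
// Vec<f32> instead of tensors. Valid only when `func` is per-position.
struct VecCache {
    state: Option<Vec<f32>>,
}

impl VecCache {
    fn forward_autoregressive<F: Fn(&[f32]) -> Vec<f32>>(&mut self, seq: &[f32], func: F) -> Vec<f32> {
        let new = match self.state.take() {
            // Cache hit: process only the newest element and append it.
            Some(mut old) => {
                old.extend(func(&seq[seq.len() - 1..]));
                old
            }
            // First call: process the whole sequence.
            None => func(seq),
        };
        self.state = Some(new.clone());
        new
    }
}

fn main() {
    let double = |s: &[f32]| s.iter().map(|x| x * 2.0).collect::<Vec<_>>();
    let mut cache = VecCache { state: None };
    assert_eq!(cache.forward_autoregressive(&[1.0], &double), vec![2.0]);
    assert_eq!(cache.forward_autoregressive(&[1.0, 3.0], &double), vec![2.0, 6.0]);
}
```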
@@ -0,0 +1,13 @@
+use crate::tensor::backend::Backend;
+use crate::tensor::Tensor;
+
+#[derive(Default)]
+pub struct TensorCache<B: Backend, const D: usize> {
+    pub(crate) state: Option<Tensor<B, D>>,
+}
+
+impl<B: Backend, const D: usize> TensorCache<B, D> {
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
@@ -0,0 +1,5 @@
+mod autoregressive;
+mod base;
+
+pub use autoregressive::*;
+pub use base::*;
@@ -1,4 +1,5 @@
 pub mod attention;
+pub mod cache;
 pub mod transformer;

 mod dropout;
@@ -1,4 +1,7 @@
-use crate as burn;
+use crate::{
+    self as burn,
+    nn::{attention::MHAAutoregressiveCache, cache::TensorCache},
+};

 use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
 use crate::{
@@ -96,6 +99,38 @@ impl<B: Backend> TransformerEncoder<B> {

         x
     }
+
+    /// Applies the forward pass on the input tensor using an autoregressive cache.
+    ///
+    /// # Shapes
+    ///
+    /// - tensor: `[batch_size, seq_length, d_model]`
+    /// - output: `[batch_size, seq_length, d_model]`
+    pub fn forward_autoregressive_inference(
+        &self,
+        input: TransformerEncoderInput<B>,
+        cache: &mut TransformerEncoderAutoregressiveCache<B>,
+    ) -> Tensor<B, 3> {
+        let mut x = input.tensor;
+
+        for i in 0..self.layers.len() {
+            let layer = self.layers.get(i).unwrap();
+            let cache = cache.layers.get_mut(i).unwrap();
+
+            x = layer.forward_autoregressive_inference(
+                x,
+                input.mask_pad.clone(),
+                input.mask_attn.clone(),
+                cache,
+            );
+        }
+
+        x
+    }
+
+    /// Create an empty autoregressive cache.
+    pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
+        TransformerEncoderAutoregressiveCache::empty(self.layers.len())
+    }
 }

 #[derive(Module, Debug)]
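Why per-layer caching is sound here: apart from attention itself, every stage in a layer (linear projections, layer norm, feed-forward) acts on each sequence position independently, and the autoregressive mask keeps attention rows for old positions fixed as new tokens arrive. A toy illustration of the per-position property (plain Rust, illustrative only):

```rust
// Stand-in for any per-position op (norm, feed-forward): each output
// element depends only on its own input element.
fn per_position(seq: &[f32]) -> Vec<f32> {
    seq.iter().map(|x| x * 2.0 + 1.0).collect()
}

fn main() {
    let short = per_position(&[1.0, 2.0]);
    let long = per_position(&[1.0, 2.0, 3.0]);
    // Growing the sequence never changes already-computed positions,
    // which is what lets the cache keep them.
    assert_eq!(short[..], long[..2]);
}
```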
@@ -156,4 +191,111 @@ impl<B: Backend> TransformerEncoderLayer<B> {

         self.norm_2.forward(x_2)
     }
+
+    fn forward_autoregressive_inference(
+        &self,
+        input: Tensor<B, 3>,
+        mask_pad: Option<BoolTensor<B, 2>>,
+        mask_attn: Option<BoolTensor<B, 3>>,
+        cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
+    ) -> Tensor<B, 3> {
+        let mut input_mhs = MhaInput::self_attn(input.clone());
+
+        if let Some(mask_pad) = mask_pad {
+            input_mhs = input_mhs.mask_pad(mask_pad);
+        }
+
+        if let Some(mask_attn) = mask_attn {
+            input_mhs = input_mhs.mask_attn(mask_attn);
+        }
+
+        let x_1 = self
+            .mha
+            .forward_autoregressive_inference(input_mhs, &mut cache.mha);
+        let x_1 = self.dropout.forward(x_1.context) + input;
+        let x_1 = cache
+            .norm_1
+            .forward_autoregressive(x_1, 1, |x_1| self.norm_1.forward(x_1));
+
+        let x_2 = cache
+            .pwff
+            .forward_autoregressive(x_1.clone(), 1, |x_1| self.pwff.forward(x_1));
+        let x_2 = self.dropout.forward(x_2) + x_1;
+
+        cache
+            .norm_2
+            .forward_autoregressive(x_2, 1, |x_2| self.norm_2.forward(x_2))
+    }
 }
+
+#[derive(Default)]
+struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
+    mha: MHAAutoregressiveCache<B>,
+    pwff: TensorCache<B, 3>,
+    norm_1: TensorCache<B, 3>,
+    norm_2: TensorCache<B, 3>,
+}
+
+impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
+    fn new() -> Self {
+        Self::default()
+    }
+}
+
+/// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer.
+///
+/// To be used during inference when decoding tokens.
+pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
+    layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
+}
+
+impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
+    fn empty(num_layers: usize) -> Self {
+        Self {
+            layers: (0..num_layers)
+                .map(|_| TransformerEncoderLayerAutoregressiveCache::new())
+                .collect(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
+    use burn_tensor::Distribution;
+
+    #[test]
+    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
+        let [batch_size, seq_length, d_model, d_ff, n_heads, num_layers] = [3, 4, 12, 24, 2, 3];
+        let transformer = TransformerEncoder::new(&TransformerEncoderConfig::new(
+            d_model, d_ff, n_heads, num_layers,
+        ));
+
+        let tensor = Tensor::<TestBackend, 3>::random(
+            [batch_size, seq_length, d_model],
+            Distribution::Standard,
+        );
+        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, tensor.device());
+        let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);
+
+        let output_1 = transformer.forward(input);
+        let mut output_2 = Vec::new();
+        let mut cache = transformer.new_autoregressive_cache();
+
+        for i in 1..seq_length + 1 {
+            let tensor = tensor.index([0..batch_size, 0..i, 0..d_model]);
+            let input = TransformerEncoderInput::new(tensor.clone());
+            let next_tok = transformer
+                .forward_autoregressive_inference(input, &mut cache)
+                .index([0..batch_size, i - 1..i, 0..d_model]);
+            output_2.push(next_tok);
+        }
+
+        let output_2 = Tensor::cat(output_2, 1);
+
+        output_1
+            .into_data()
+            .assert_approx_eq(&output_2.into_data(), 3);
+    }
+}