Feat/autoregressive transformer (#125)

Nathaniel Simard 2022-12-10 15:47:57 -05:00 committed by GitHub
parent b99b23e1a7
commit 63d8d39517
10 changed files with 360 additions and 9 deletions


@@ -6,7 +6,7 @@ use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use crate::{to_nd_array_tensor, NdArrayDevice, SEED};
use burn_tensor::Distribution;
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
use ndarray::{Axis, Dim, IxDyn, SliceInfoElem};
use ndarray::{ArcArray, Axis, Dim, Dimension, IxDyn, SliceInfoElem};
use rand::rngs::StdRng;
use rand::SeedableRng;
@@ -527,14 +527,12 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
}
fn cat<const D: usize>(tensors: &[NdArrayTensor<E, D>], dim: usize) -> NdArrayTensor<E, D> {
let mut shape = tensors.get(0).unwrap().shape;
shape.dims[dim] = tensors.len();
let arrays: Vec<ndarray::ArrayView<E, IxDyn>> =
tensors.iter().map(|t| t.array.view()).collect();
let array = ndarray::concatenate(Axis(dim), &arrays)
.unwrap()
.into_shared();
let shape = array_shape(&array);
NdArrayTensor { array, shape }
}
@@ -646,3 +644,13 @@ fn cmp_min(a: &f64, b: &f64) -> Ordering {
}
Ordering::Equal
}
fn array_shape<const D: usize, E>(array: &ArcArray<E, IxDyn>) -> Shape<D> {
let dims = array
.raw_dim()
.slice()
.iter()
.map(|a| *a as i64)
.collect::<Vec<_>>();
Shape::from(dims)
}
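For context, a minimal standalone sketch of the ndarray behaviour this relies on (plain ndarray, independent of burn's types): the size of the concatenation axis in the result is the sum of the inputs' sizes along that axis, which is why the output shape is now read back from the concatenated array rather than derived from the first input.

use ndarray::{concatenate, Array2, Axis};

fn main() {
    // Two 2x3 matrices concatenated along axis 1 give a 2x6 result:
    // the axis-1 size is 3 + 3 = 6, not the number of inputs (2).
    let a = Array2::<f32>::zeros((2, 3));
    let b = Array2::<f32>::ones((2, 3));
    let c = concatenate(Axis(1), &[a.view(), b.view()]).unwrap();
    assert_eq!(c.shape(), &[2, 6]);
}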


@@ -399,6 +399,12 @@ where
Self::new(tensor)
}
/// Create a tensor of the given shape, on the given device, where each element is one.
pub fn ones_device<S: Into<Shape<D>>>(shape: S, device: B::Device) -> Self {
let tensor = B::ones(shape.into(), device);
Self::new(tensor)
}
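A quick usage sketch of the new constructor, mirroring the test setup used elsewhere in this commit (TestBackend and its default device; the shape values are illustrative):

    let device = <TestBackend as Backend>::Device::default();
    let ones = Tensor::<TestBackend, 2>::ones_device([2, 3], device);
    assert_eq!(ones.dims(), [2, 3]);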
/// Returns a tensor containing the elements selected from the given ranges.
///
/// # Panics


@@ -0,0 +1,50 @@
use burn_tensor::{backend::Backend, BoolTensor, ElementConversion, Tensor};
/// Generate an autoregressive attention mask.
///
/// The mask can be used in Transformer modules to train models to generate sequences autoregressively: each position is blocked from attending to later positions.
pub fn generate_autoregressive_mask<B: Backend>(
batch_size: usize,
seq_length: usize,
device: B::Device,
) -> BoolTensor<B, 3> {
let mut mask = Tensor::<B::IntegerBackend, 3>::zeros([1, seq_length, seq_length]);
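// Row i marks columns i+1..seq_length (the future positions) with ones; the
// final equal_scalar(1) below turns exactly those entries into `true`, i.e. the
// positions to be masked out.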
for i in 0..seq_length {
let values = Tensor::<B::IntegerBackend, 3>::ones([1, 1, seq_length - (i + 1)]);
mask = mask.index_assign([0..1, i..i + 1, i + 1..seq_length], &values);
}
mask = mask.to_device(device).repeat(0, batch_size);
BoolTensor::from_int_backend(mask.equal_scalar(1_i64.to_elem::<i64>()))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn_tensor::Data;
#[test]
fn test_generate_autoregressive_mask() {
let device = <TestBackend as Backend>::Device::default();
let mask = generate_autoregressive_mask::<TestBackend>(2, 3, device);
assert_eq!(
mask.into_data(),
Data::from([
[
[false, true, true],
[false, false, true],
[false, false, false],
],
[
[false, true, true],
[false, false, true],
[false, false, false],
]
])
);
}
}
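For intuition, a minimal framework-free sketch of how such a mask is typically consumed (hand-rolled arrays here, not burn's API): entries that are `true` are filled with a large negative value before the softmax, so position i can only attend to positions 0..=i.

// Illustrative only: applies one row of a causal mask to raw attention scores.
fn mask_row(scores: &mut [f32], mask: &[bool]) {
    for (score, &masked) in scores.iter_mut().zip(mask) {
        if masked {
            *score = f32::NEG_INFINITY; // exp(-inf) = 0 after the softmax
        }
    }
}

fn main() {
    // Row i = 1 of the seq_length = 3 mask from the test above: [false, false, true].
    let mut scores = [0.7, 0.1, 0.9];
    mask_row(&mut scores, &[false, false, true]);
    assert_eq!(scores, [0.7, 0.1, f32::NEG_INFINITY]);
}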


@@ -1,5 +1,6 @@
use crate as burn;
use crate::nn::cache::TensorCache;
use crate::{
config::Config,
module::{Module, Param},
@@ -46,7 +47,7 @@ pub struct MultiHeadAttention<B: Backend> {
}
/// [Multihead attention](MultiHeadAttention) forward pass input argument.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct MhaInput<B: Backend> {
query: Tensor<B, 3>,
key: Tensor<B, 3>,
@@ -93,7 +94,7 @@ impl<B: Backend> MhaInput<B> {
}
/// [Multihead attention](MultiHeadAttention) outputs.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct MhaOutput<B: Backend> {
/// The attention weights [batch_size, n_heads, seq_length_1, seq_length_2].
pub weights: Tensor<B, 4>,
@@ -151,6 +152,51 @@ impl<B: Backend> MultiHeadAttention<B> {
MhaOutput { weights, context }
}
/// Applies the forward pass on the input tensors using an autoregressive cache.
///
/// # Shapes
///
/// - query: `[batch_size, seq_length_1, d_model]`
/// - key: `[batch_size, seq_length_2, d_model]`
/// - value: `[batch_size, seq_length_2, d_model]`
/// - output: `[batch_size, seq_length_1, d_model]`
pub fn forward_autoregressive_inference(
&self,
input: MhaInput<B>,
cache: &mut MHAAutoregressiveCache<B>,
) -> MhaOutput<B> {
let [batch_size, seq_length_1, d_model] = input.query.dims();
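// Each projection (query / key / value) keeps its own cache: only the newest
// position is run through the linear projection, and the result is appended
// along the sequence axis of the cached 4-D projection.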
let attention_linear = |cache: &mut TensorCache<B, 4>,
tensor: Tensor<B, 3>,
param: &Param<nn::Linear<B>>| {
cache.forward_autoregressive(tensor, 2, |tensor| self.attention_linear(tensor, param))
};
let query = attention_linear(&mut cache.query, input.query, &self.query);
let key = attention_linear(&mut cache.key, input.key, &self.key);
let value = attention_linear(&mut cache.value, input.value, &self.value);
let attn_scores = self.attn_scores(query, key);
let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
let context = weights.matmul(&value);
let context = context
.swap_dims(1, 2)
.reshape([batch_size, seq_length_1, d_model]);
let context = cache
.output
.forward_autoregressive(context, 1, |context| self.output.forward(context));
MhaOutput { weights, context }
}
/// Create an empty autoregressive cache.
pub fn new_autoregressive_cache(&self) -> MHAAutoregressiveCache<B> {
MHAAutoregressiveCache::default()
}
fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
let attn_scores = query
.matmul(&key.transpose())
@@ -195,10 +241,21 @@ impl<B: Backend> MultiHeadAttention<B> {
}
}
/// Autoregressive cache for the [Multi Head Attention](MultiHeadAttention) layer.
///
/// To be used during inference when decoding tokens.
#[derive(Default)]
pub struct MHAAutoregressiveCache<B: Backend> {
query: TensorCache<B, 4>,
key: TensorCache<B, 4>,
value: TensorCache<B, 4>,
output: TensorCache<B, 3>,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
use burn::tensor::{Distribution, Shape};
#[test]
@@ -253,7 +310,7 @@ mod tests {
}
#[test]
pub fn test_self_attention_mask_pad() {
fn test_self_attention_mask_pad() {
let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2];
let mha = MultiHeadAttention::new(&MultiHeadAttentionConfig::new(d_model, n_heads));
@@ -298,4 +355,38 @@ mod tests {
3,
);
}
#[test]
fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2];
let mha = MultiHeadAttention::new(&MultiHeadAttentionConfig::new(d_model, n_heads));
let tensor = Tensor::<TestBackend, 3>::random(
[batch_size, seq_length, d_model],
Distribution::Standard,
);
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, tensor.device());
let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);
let output_1 = mha.forward(input);
let mut output_2 = Vec::new();
let mut cache = mha.new_autoregressive_cache();
for i in 1..seq_length + 1 {
let tensor = tensor.index([0..batch_size, 0..i, 0..d_model]);
let input = MhaInput::self_attn(tensor);
let next_tok = mha
.forward_autoregressive_inference(input, &mut cache)
.context
.index([0..batch_size, i - 1..i, 0..d_model]);
output_2.push(next_tok);
}
let output_2 = Tensor::cat(output_2, 1);
output_1
.context
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
}
}


@@ -1,3 +1,5 @@
mod mask;
mod mha;
pub use mask::*;
pub use mha::*;

burn/src/nn/cache/autoregressive.rs (new file)

@@ -0,0 +1,33 @@
use super::TensorCache;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
impl<B: Backend, const D: usize> TensorCache<B, D> {
pub(crate) fn forward_autoregressive<F>(
&mut self,
tensor: Tensor<B, 3>,
dim_cat: usize,
func: F,
) -> Tensor<B, D>
where
F: Fn(Tensor<B, 3>) -> Tensor<B, D>,
{
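// Take ownership of the cached state, leaving None in its place
// (same effect as `self.state.take()`).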
let mut tensor_old = None;
std::mem::swap(&mut self.state, &mut tensor_old);
let tensor_new = match tensor_old {
Some(tensor_old) => {
let [batch_size, seq_length, d_model] = tensor.dims();
let next_seq_token =
tensor.index([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]);
let next_seq_token = func(next_seq_token);
Tensor::cat(vec![tensor_old, next_seq_token], dim_cat)
}
None => func(tensor),
};
self.state = Some(tensor_new.clone());
tensor_new
}
}
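To make the control flow concrete, here is a framework-free sketch of the same caching idea (illustrative names only; a Vec<f32> stands in for one position's features and `project` for the per-position function, e.g. a linear layer): the first call processes the whole sequence, every later call only runs the newest position through `project` and appends it to the cached result.

struct SeqCache {
    state: Option<Vec<Vec<f32>>>,
}

impl SeqCache {
    fn forward<F>(&mut self, sequence: &[Vec<f32>], project: F) -> Vec<Vec<f32>>
    where
        F: Fn(&Vec<f32>) -> Vec<f32>,
    {
        let new_state = match self.state.take() {
            // Later calls: only the newest position is projected and appended.
            Some(mut cached) => {
                let newest = sequence.last().expect("non-empty sequence");
                cached.push(project(newest));
                cached
            }
            // First call: project every position of the sequence.
            None => sequence.iter().map(&project).collect(),
        };
        self.state = Some(new_state.clone());
        new_state
    }
}

fn main() {
    let mut cache = SeqCache { state: None };
    let double = |x: &Vec<f32>| x.iter().map(|v| v * 2.0).collect::<Vec<f32>>();

    // Step 1: the whole (length-1) sequence goes through `project`.
    let step_1 = cache.forward(&[vec![1.0]], double);
    // Step 2: only the newest position is projected; the first one comes from the cache.
    let step_2 = cache.forward(&[vec![1.0], vec![3.0]], double);

    assert_eq!(step_1, vec![vec![2.0]]);
    assert_eq!(step_2, vec![vec![2.0], vec![6.0]]);
}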

burn/src/nn/cache/base.rs (new file)

@@ -0,0 +1,13 @@
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
#[derive(Default)]
pub struct TensorCache<B: Backend, const D: usize> {
pub(crate) state: Option<Tensor<B, D>>,
}
impl<B: Backend, const D: usize> TensorCache<B, D> {
pub fn new() -> Self {
Self::default()
}
}

burn/src/nn/cache/mod.rs (new file)

@@ -0,0 +1,5 @@
mod autoregressive;
mod base;
pub use autoregressive::*;
pub use base::*;


@@ -1,4 +1,5 @@
pub mod attention;
pub mod cache;
pub mod transformer;
mod dropout;


@@ -1,4 +1,7 @@
use crate as burn;
use crate::{
self as burn,
nn::{attention::MHAAutoregressiveCache, cache::TensorCache},
};
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::{
@@ -96,6 +99,38 @@ impl<B: Backend> TransformerEncoder<B> {
x
}
/// Applies the forward pass on the input tensor using an autoregressive cache.
///
/// # Shapes
///
/// - tensor: `[batch_size, seq_length, d_model]`
/// - output: `[batch_size, seq_length, d_model]`
pub fn forward_autoregressive_inference(
&self,
input: TransformerEncoderInput<B>,
cache: &mut TransformerEncoderAutoregressiveCache<B>,
) -> Tensor<B, 3> {
let mut x = input.tensor;
for i in 0..self.layers.len() {
let layer = self.layers.get(i).unwrap();
let cache = cache.layers.get_mut(i).unwrap();
x = layer.forward_autoregressive_inference(
x,
input.mask_pad.clone(),
input.mask_attn.clone(),
cache,
);
}
x
}
/// Create an empty autoregressive cache.
pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
TransformerEncoderAutoregressiveCache::empty(self.layers.len())
}
}
#[derive(Module, Debug)]
@@ -156,4 +191,111 @@ impl<B: Backend> TransformerEncoderLayer<B> {
self.norm_2.forward(x_2)
}
fn forward_autoregressive_inference(
&self,
input: Tensor<B, 3>,
mask_pad: Option<BoolTensor<B, 2>>,
mask_attn: Option<BoolTensor<B, 3>>,
cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
) -> Tensor<B, 3> {
let mut input_mhs = MhaInput::self_attn(input.clone());
if let Some(mask_pad) = mask_pad {
input_mhs = input_mhs.mask_pad(mask_pad);
}
if let Some(mask_attn) = mask_attn {
input_mhs = input_mhs.mask_attn(mask_attn);
}
let x_1 = self
.mha
.forward_autoregressive_inference(input_mhs, &mut cache.mha);
let x_1 = self.dropout.forward(x_1.context) + input;
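// Layer norm and the position-wise feed-forward only act per position, so the
// cached outputs for earlier positions stay valid; only the newest position
// needs to be recomputed and appended.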
let x_1 = cache
.norm_1
.forward_autoregressive(x_1, 1, |x_1| self.norm_1.forward(x_1));
let x_2 = cache
.pwff
.forward_autoregressive(x_1.clone(), 1, |x_1| self.pwff.forward(x_1));
let x_2 = self.dropout.forward(x_2) + x_1;
cache
.norm_2
.forward_autoregressive(x_2, 1, |x_2| self.norm_2.forward(x_2))
}
}
#[derive(Default)]
struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
mha: MHAAutoregressiveCache<B>,
pwff: TensorCache<B, 3>,
norm_1: TensorCache<B, 3>,
norm_2: TensorCache<B, 3>,
}
impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
fn new() -> Self {
Self::default()
}
}
/// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer.
///
/// To be used during inference when decoding tokens.
pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
}
impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
fn empty(num_layers: usize) -> Self {
Self {
layers: (0..num_layers)
.map(|_| TransformerEncoderLayerAutoregressiveCache::new())
.collect(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
use burn_tensor::Distribution;
#[test]
fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
let [batch_size, seq_length, d_model, d_ff, n_heads, num_layers] = [3, 4, 12, 24, 2, 3];
let transformer = TransformerEncoder::new(&TransformerEncoderConfig::new(
d_model, d_ff, n_heads, num_layers,
));
let tensor = Tensor::<TestBackend, 3>::random(
[batch_size, seq_length, d_model],
Distribution::Standard,
);
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, tensor.device());
let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);
let output_1 = transformer.forward(input);
let mut output_2 = Vec::new();
let mut cache = transformer.new_autoregressive_cache();
for i in 1..seq_length + 1 {
let tensor = tensor.index([0..batch_size, 0..i, 0..d_model]);
let input = TransformerEncoderInput::new(tensor.clone());
let next_tok = transformer
.forward_autoregressive_inference(input, &mut cache)
.index([0..batch_size, i - 1..i, 0..d_model]);
output_2.push(next_tok);
}
let output_2 = Tensor::cat(output_2, 1);
output_1
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
}
}