From fb1da53a380bab80312ee971f1dcb2740efde93c Mon Sep 17 00:00:00 2001 From: Aasheesh Singh <20820983+ashdtu@users.noreply.github.com> Date: Fri, 12 Apr 2024 11:45:49 -0400 Subject: [PATCH] support for rotary positional encoding to transformer modules. (#1604) * add rotary positional encoding to transformer modules. * fix f64 error * use num_traits * add panic condition --- burn-book/src/building-blocks/module.md | 17 +- crates/burn-core/src/nn/mod.rs | 2 + crates/burn-core/src/nn/rope_encoding.rs | 219 +++++++++++++++++++++++ 3 files changed, 230 insertions(+), 8 deletions(-) create mode 100644 crates/burn-core/src/nn/rope_encoding.rs diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index 51cde36ec..0a5d81eaa 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -52,7 +52,7 @@ the `Module` derive, you need to be careful to achieve the behavior you want. These methods are available for all modules. | Burn API | PyTorch Equivalent | -| --------------------------------------- | ---------------------------------------- | +|-----------------------------------------|------------------------------------------| | `module.devices()` | N/A | | `module.fork(device)` | Similar to `module.to(device).detach()` | | `module.to_device(device)` | `module.to(device)` | @@ -69,7 +69,7 @@ Similar to the backend trait, there is also the `AutodiffModule` trait to signif autodiff support. | Burn API | PyTorch Equivalent | -| ---------------- | ------------------ | +|------------------|--------------------| | `module.valid()` | `module.eval()` | ## Visitor & Mapper @@ -107,7 +107,7 @@ Burn comes with built-in modules that you can use to build your own modules. ### General | Burn API | PyTorch Equivalent | -| -------------- | --------------------------------------------- | +|----------------|-----------------------------------------------| | `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. | | `Dropout` | `nn.Dropout` | | `Embedding` | `nn.Embedding` | @@ -125,7 +125,7 @@ Burn comes with built-in modules that you can use to build your own modules. ### Convolutions | Burn API | PyTorch Equivalent | -| ----------------- | -------------------- | +|-------------------|----------------------| | `Conv1d` | `nn.Conv1d` | | `Conv2d` | `nn.Conv2d` | | `ConvTranspose1d` | `nn.ConvTranspose1d` | @@ -134,7 +134,7 @@ Burn comes with built-in modules that you can use to build your own modules. ### Pooling | Burn API | PyTorch Equivalent | -| ------------------- | ---------------------- | +|---------------------|------------------------| | `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` | | `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` | | `AvgPool1d` | `nn.AvgPool1d` | @@ -145,7 +145,7 @@ Burn comes with built-in modules that you can use to build your own modules. ### RNNs | Burn API | PyTorch Equivalent | -| ---------------- | ---------------------- | +|------------------|------------------------| | `Gru` | `nn.GRU` | | `Lstm` | `nn.LSTM` | | `GateController` | _No direct equivalent_ | @@ -153,16 +153,17 @@ Burn comes with built-in modules that you can use to build your own modules. 
 ### Transformer
 
 | Burn API | PyTorch Equivalent |
-| -------------------- | ----------------------- |
+|----------------------|-------------------------|
 | `MultiHeadAttention` | `nn.MultiheadAttention` |
 | `TransformerDecoder` | `nn.TransformerDecoder` |
 | `TransformerEncoder` | `nn.TransformerEncoder` |
 | `PositionalEncoding` | _No direct equivalent_ |
+| `RotaryEncoding` | _No direct equivalent_ |
 
 ### Loss
 
 | Burn API | PyTorch Equivalent |
-| ------------------ | --------------------- |
+|--------------------|-----------------------|
 | `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
 | `MseLoss` | `nn.MSELoss` |
 | `HuberLoss` | `nn.HuberLoss` |
diff --git a/crates/burn-core/src/nn/mod.rs b/crates/burn-core/src/nn/mod.rs
index ab462a1cb..2fada289e 100644
--- a/crates/burn-core/src/nn/mod.rs
+++ b/crates/burn-core/src/nn/mod.rs
@@ -28,6 +28,7 @@ mod pos_encoding;
 mod prelu;
 mod relu;
 mod rnn;
+mod rope_encoding;
 mod swiglu;
 mod unfold;
 
@@ -43,5 +44,6 @@ pub use pos_encoding::*;
 pub use prelu::*;
 pub use relu::*;
 pub use rnn::*;
+pub use rope_encoding::*;
 pub use swiglu::*;
 pub use unfold::*;
diff --git a/crates/burn-core/src/nn/rope_encoding.rs b/crates/burn-core/src/nn/rope_encoding.rs
new file mode 100644
index 000000000..6eb732e4b
--- /dev/null
+++ b/crates/burn-core/src/nn/rope_encoding.rs
@@ -0,0 +1,219 @@
+use crate as burn;
+use crate::config::Config;
+use crate::module::Module;
+use crate::tensor::backend::Backend;
+use crate::tensor::Tensor;
+use alloc::vec;
+use burn_tensor::Int;
+
+#[cfg(not(feature = "std"))]
+use num_traits::Float;
+
+/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer.
+#[derive(Config, Debug)]
+pub struct RotaryEncodingConfig {
+    /// Maximum sequence length of input
+    max_sequence_length: usize,
+
+    /// Size of the input embedding or hidden dimension
+    d_model: usize,
+
+    /// Scaling factor for frequency computation. Defaults to 10000.0
+    #[config(default = "10000.0")]
+    theta: f32,
+}
+
+impl RotaryEncodingConfig {
+    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the size of the input embedding dimension is not even.
+    /// Panics if the theta parameter is not positive.
+    pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
+        assert_eq!(
+            self.d_model % 2,
+            0,
+            "The input embedding dimension must be even"
+        );
+        assert!(
+            self.theta > 0.0,
+            "Theta parameter must be positive (default: 10000)."
+        );
+
+        // Calculate the rotation frequencies for positional embeddings based on the formula
+        // `theta_i = 1 / (10000 ^ (2i / d_model)) for i in [0..d_model/2]`
+        let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
+            .float()
+            .div_scalar(self.d_model as f32);
+
+        // Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
+        // This is done since burn doesn't support raising a scalar to a tensor exponent
+        let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
+        let theta_i = theta_i.powf_scalar(-1.0);
+
+        // Generate frequency values for positional embeddings
+        let frequencies: Tensor<B, 2> =
+            Tensor::<B, 1, Int>::arange(0..self.max_sequence_length as i64, device)
+                .float()
+                .unsqueeze()
+                .transpose()
+                .repeat(1, self.d_model / 2)
+                * theta_i.unsqueeze();
+
+        // Convert frequency values to complex numbers (polar form)
+        let p_cos = frequencies.clone().cos();
+        let p_sin = frequencies.sin();
+
+        // Create the frequency tensor of shape (max_sequence_length, d_model, 2) with the real (cos)
+        // and imaginary (sin) components along the last dimension
+        let freq_complex: Tensor<B, 3> = Tensor::cat(vec![p_cos, p_sin], 1)
+            .reshape([self.max_sequence_length, 2, self.d_model / 2])
+            .transpose()
+            .unsqueeze_dim::<4>(2)
+            .repeat(2, 2)
+            .reshape([self.max_sequence_length, self.d_model, 2]);
+
+        RotaryEncoding { freq_complex }
+    }
+}
+
+/// A module that applies rotary positional encoding to a tensor.
+/// Rotary Position Encoding or Embedding (RoPE) is a type of position embedding that encodes
+/// absolute positional information with a rotation matrix and naturally incorporates
+/// explicit relative position dependency in the self-attention formulation.
+///
+/// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
+#[derive(Module, Debug)]
+pub struct RotaryEncoding<B: Backend> {
+    /// Frequency tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components
+    freq_complex: Tensor<B, 3>,
+}
+
+#[allow(clippy::single_range_in_vec_init)]
+impl<B: Backend> RotaryEncoding<B> {
+    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
+    ///
+    /// Arguments:
+    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodates both 3D and 4D tensors
+    /// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
+    /// respectively.
+    ///
+    /// Returns:
+    /// * Output tensor with the same shape as the input tensor after applying rotary encoding.
+    ///
+    /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
+    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
+        assert!(
+            D >= 2,
+            "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
+        );
+
+        let device = x.device();
+        let input_shape = x.shape();
+
+        // Extract the sequence length and embedding dimension; other dimensions are kept generic
+        // to allow both 3D and 4D tensors, i.e. batch_size or (batch_size, num_heads)
+        let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
+        let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);
+
+        // Create a dummy tensor with signed ones based on the 2D rotation matrix
+        // [[cos, -sin], [sin, cos]]
+        let sign_tensor =
+            Tensor::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);
+
+        // Rotate input using the frequency tensor. Slice the frequencies up to the input sequence length
+        let out: Tensor<B, 4> = x
+            .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
+            .matmul(sign_tensor.unsqueeze())
+            .reshape([dummy_dim_size, seq_len, d_model, 2])
+            * self.freq_complex.clone().slice([0..seq_len]).unsqueeze();
+
+        // Sum the real and imaginary components to get the output tensor and reshape to the original shape
+        out.sum_dim(D - 1).reshape(input_shape)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::TestBackend;
+
+    #[test]
+    fn test_rotary_encoding_forward() {
+        let device = Default::default();
+        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
+
+        let input = Tensor::from_floats(
+            [
+                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
+                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
+            ],
+            &device,
+        );
+
+        // Input = [Batch size, Num of heads, Seq_len, d_model]
+        let input = input.unsqueeze::<4>();
+
+        let output = rotary_encoding.forward(input);
+        let expected_output = Tensor::<TestBackend, 3>::from_floats(
+            [
+                [
+                    [1.0000, 2.0000, 3.0000, 4.0000],
+                    [-2.3473, 7.4492, 6.9197, 8.0696],
+                ],
+                [
+                    [9.0000, 10.0000, 11.0000, 12.0000],
+                    [-4.7567, 18.5034, 14.8393, 16.1492],
+                ],
+            ],
+            &device,
+        );
+
+        output
+            .squeeze(0)
+            .to_data()
+            .assert_approx_eq(&expected_output.to_data(), 4);
+    }
+
+    #[test]
+    fn test_zero_input_rotary_encoding_forward() {
+        let device = Default::default();
+        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
+
+        // Use a tensor of exact zeros as input. The output rotary embedding should be zeros as well
+        let input = Tensor::zeros([1, 2, 2, 4], &device);
+
+        let output = rotary_encoding.forward(input);
+        let expected_output = Tensor::<TestBackend, 3>::from_floats(
+            [
+                [
+                    [0.0000, 0.0000, 0.0000, 0.0000],
+                    [0.0000, 0.0000, 0.0000, 0.0000],
+                ],
+                [
+                    [0.0000, 0.0000, 0.0000, 0.0000],
+                    [0.0000, 0.0000, 0.0000, 0.0000],
+                ],
+            ],
+            &device,
+        );
+
+        output
+            .squeeze(0)
+            .to_data()
+            .assert_approx_eq(&expected_output.to_data(), 4);
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_valid_input_hidden_dim() {
+        // Hidden dimension must be even to be able to split into real and imaginary components
+        // for rotation
+        let d_model = 15;
+        let device = Default::default();
+        let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);
+        let input = Tensor::zeros([1, 5, d_model], &device);
+        let _output = pe.forward(input);
+    }
+}
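As a quick reference (not part of the diff above): the frequency table built in `init` and the pairwise rotation applied in `forward` follow the standard RoPE formulation from the RoFormer paper linked in the module docs. Using illustrative notation where $m$ is the token position, $\theta$ is the configured base (default 10000), and $(x_{2i}, x_{2i+1})$ is the $i$-th feature pair:

$$\theta_i = \theta^{-2i/d_{model}}, \qquad i = 0, 1, \ldots, \frac{d_{model}}{2} - 1$$

$$\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$$

The `freq_complex` tensor stores $\cos(m\theta_i)$ and $\sin(m\theta_i)$ for every position up to `max_sequence_length`, and the `sign_tensor` matmul in `forward` assembles the two rotated components before they are summed over the last dimension.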
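Below is a minimal usage sketch of the new module, not taken from the patch. It assumes the `NdArray` backend re-export (behind the `ndarray` feature) and that `RotaryEncoding`/`RotaryEncodingConfig` are reachable through `burn::nn` via the `pub use rope_encoding::*;` line added above; shapes and hyperparameters are illustrative.

```rust
// Sketch only: assumes the `NdArray` backend re-export and the `burn::nn` path
// for the new types; adjust to your backend of choice.
use burn::backend::NdArray;
use burn::nn::{RotaryEncoding, RotaryEncodingConfig};
use burn::tensor::Tensor;

type B = NdArray;

fn main() {
    let device = Default::default();

    // max_sequence_length = 512, d_model (per-head dim) = 64; theta keeps its 10000.0 default.
    let rope: RotaryEncoding<B> = RotaryEncodingConfig::new(512, 64).init(&device);

    // Query and key tensors shaped (batch, num_heads, seq_len, head_dim).
    let q = Tensor::<B, 4>::ones([2, 8, 10, 64], &device);
    let k = Tensor::<B, 4>::ones([2, 8, 10, 64], &device);

    // RoPE is applied to queries and keys before the attention score computation;
    // the output keeps the input shape.
    let q = rope.forward(q);
    let k = rope.forward(k);

    assert_eq!(q.dims(), [2, 8, 10, 64]);
    assert_eq!(k.dims(), [2, 8, 10, 64]);
}
```

In an attention block this would typically be applied to the query and key projections of each head, right after their linear layers.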