Add seq start position when applying RoPE encoding (#1796)

Guillaume Lagrange 2024-05-22 13:18:31 -04:00 committed by GitHub
parent 0918cf00c6
commit b466fd7606
1 changed file with 21 additions and 1 deletion

@@ -104,6 +104,22 @@ impl<B: Backend> RotaryEncoding<B> {
     ///
     /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
     pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
+        self.apply(x, 0)
+    }
+
+    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
+    ///
+    /// Arguments:
+    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodates both 3D and 4D tensors
+    ///   for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim),
+    ///   respectively.
+    /// * `start` - Sequence start position index.
+    ///
+    /// Returns:
+    /// * Output tensor with the same shape as the input tensor after applying rotary encoding.
+    ///
+    /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
+    pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
         assert!(
             D >= 2,
             "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
@@ -127,7 +143,11 @@ impl<B: Backend> RotaryEncoding<B> {
             .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
             .matmul(sign_tensor.unsqueeze())
             .reshape([dummy_dim_size, seq_len, d_model, 2])
-            * self.freq_complex.clone().slice([0..seq_len]).unsqueeze();
+            * self
+                .freq_complex
+                .clone()
+                .slice([start..start + seq_len])
+                .unsqueeze();
 
         // Sum the real and imaginary components to get output tensor and reshape to original shape
         out.sum_dim(D - 1).reshape(input_shape)
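
The change keeps `forward(x)` as a shorthand for `apply(x, 0)`, so existing callers are unaffected, while incremental decoding can pass the number of already-processed tokens as `start` so the frequency slice lines up with each token's absolute position. Below is a minimal usage sketch (not part of the diff), assuming Burn's `RotaryEncodingConfig` with a `(max_sequence_length, d_model)` constructor and the `NdArray` backend; the shapes and lengths are illustrative only.

```rust
use burn::backend::NdArray;
use burn::nn::RotaryEncodingConfig;
use burn::tensor::Tensor;

type B = NdArray;

fn main() {
    let device = Default::default();
    let (max_len, d_model) = (512, 64);

    // Assumed config API: (max_sequence_length, d_model).
    let rope = RotaryEncodingConfig::new(max_len, d_model).init::<B>(&device);

    // Prefill: encode a full prompt starting at position 0.
    // `forward` is equivalent to `apply(x, 0)`.
    let prompt: Tensor<B, 3> = Tensor::zeros([1, 10, d_model], &device);
    let _encoded = rope.forward(prompt);

    // Decode step: one new token whose absolute position is 10, so its
    // rotation angles are taken from the slice starting at the cached length.
    let next: Tensor<B, 3> = Tensor::zeros([1, 1, d_model], &device);
    let _encoded_next = rope.apply(next, 10);
}
```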