mirror of https://github.com/tracel-ai/burn.git
Add seq start position when applying RoPE encoding (#1796)
This commit is contained in:
parent
0918cf00c6
commit
b466fd7606
|
@ -104,6 +104,22 @@ impl<B: Backend> RotaryEncoding<B> {
|
|||
///
|
||||
/// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
|
||||
pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
    // A plain forward pass encodes from the first sequence position, so it
    // simply delegates to `apply` with a start index of zero.
    let start = 0;
    self.apply(x, start)
}
|
||||
|
||||
/// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
|
||||
///
|
||||
/// Arguments:
|
||||
/// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodates both 3D and 4D tensors
|
||||
/// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
|
||||
/// respectively.
|
||||
/// * `start` - Sequence start position index.
|
||||
///
|
||||
/// Returns:
|
||||
/// * Output tensor with the same shape as input tensor after applying rotary encoding.
|
||||
///
|
||||
/// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
|
||||
pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
|
||||
assert!(
|
||||
D >= 2,
|
||||
"Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
|
||||
|
@ -127,7 +143,11 @@ impl<B: Backend> RotaryEncoding<B> {
|
|||
.reshape([dummy_dim_size, seq_len, d_model / 2, 2])
|
||||
.matmul(sign_tensor.unsqueeze())
|
||||
.reshape([dummy_dim_size, seq_len, d_model, 2])
|
||||
* self.freq_complex.clone().slice([0..seq_len]).unsqueeze();
|
||||
* self
|
||||
.freq_complex
|
||||
.clone()
|
||||
.slice([start..start + seq_len])
|
||||
.unsqueeze();
|
||||
|
||||
// Sum the real and imaginary components to get output tensor and reshape to original shape
|
||||
out.sum_dim(D - 1).reshape(input_shape)
|
||||
|
|
Loading…
Reference in New Issue