fixed whisper positional encoding (#23167)

This commit is contained in:
Andrei Filatov 2023-05-05 18:36:15 +03:00 committed by GitHub
parent 1b9c352e55
commit 77412343c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@@ -128,7 +128,7 @@ class TFWhisperPositionalEmbedding(tf.keras.layers.Layer):
def call(self, input_ids, past_key_values_length=0):
past_key_values_length = tf.cast(past_key_values_length, tf.int32)
-        gather_indices = tf.range(tf.shape(input_ids)[-1], delta=1) + past_key_values_length
+        gather_indices = tf.range(tf.shape(input_ids)[1], delta=1) + past_key_values_length
return tf.gather(self.weight, gather_indices)

View File

@@ -226,7 +226,7 @@ class WhisperPositionalEmbedding(nn.Embedding):
super().__init__(num_positions, embedding_dim)
def forward(self, input_ids, past_key_values_length=0):
-        return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[-1]]
+        return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
class WhisperAttention(nn.Module):