fixed whisper positional encoding (#23167)
This commit is contained in:
parent
1b9c352e55
commit
77412343c8
|
@ -128,7 +128,7 @@ class TFWhisperPositionalEmbedding(tf.keras.layers.Layer):
|
|||
|
||||
def call(self, input_ids, past_key_values_length=0):
    """Gather the positional embeddings for the positions covered by the input.

    Args:
        input_ids: Tensor whose axis 1 gives the number of positions to embed.
            NOTE(review): presumably ``(batch, seq)`` token ids or
            ``(batch, seq, dim)`` embeddings -- confirm against callers.
        past_key_values_length: Offset added to every position index; it is
            cast to ``tf.int32`` before use.

    Returns:
        The rows of ``self.weight`` at indices
        ``past_key_values_length + [0, tf.shape(input_ids)[1])``.
    """
    past_key_values_length = tf.cast(past_key_values_length, tf.int32)
    # Use axis 1 rather than the last axis: the two only coincide for a 2-D
    # input. The earlier ``[-1]`` variant was a dead store overwritten by this
    # line, so it has been dropped.
    gather_indices = tf.range(tf.shape(input_ids)[1], delta=1) + past_key_values_length
    return tf.gather(self.weight, gather_indices)
|
||||
|
||||
|
||||
|
|
|
@ -226,7 +226,7 @@ class WhisperPositionalEmbedding(nn.Embedding):
|
|||
super().__init__(num_positions, embedding_dim)
|
||||
|
||||
def forward(self, input_ids, past_key_values_length=0):
    """Return the slice of the position table covering the current input.

    Args:
        input_ids: Object exposing ``.shape``; ``shape[1]`` is used as the
            number of positions to return.
            NOTE(review): presumably ``(batch, seq)`` token ids or
            ``(batch, seq, dim)`` embeddings -- confirm against callers.
        past_key_values_length: Index of the first position to return.

    Returns:
        ``self.weight[past_key_values_length : past_key_values_length + shape[1]]``.
    """
    # Slice by axis 1, not the last axis: for a 3-D input the last axis is a
    # different dimension, so ``shape[-1]`` would return the wrong number of
    # rows. The earlier ``shape[-1]`` return made this statement unreachable.
    return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
|
||||
|
||||
|
||||
class WhisperAttention(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue