[`TF`] Also apply patch to support left padding (#25085)

* tf versions

* apply changes to other models

* 3 models slipped through the cracks
This commit is contained in:
Arthur 2023-07-25 17:23:09 +02:00 committed by GitHub
parent f104522718
commit 2fac342238
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 19 additions and 43 deletions

View File

@@ -785,7 +785,9 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
 sequence_lengths = -1
 else:
 if input_ids is not None:
-sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+logits.device
+)
 else:
 sequence_lengths = -1
 logger.warning(

View File

@@ -798,16 +798,10 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
 else:
 if input_ids is not None:
 sequence_lengths = (
-tf.reduce_sum(
-tf.cast(
-tf.math.not_equal(input_ids, self.config.pad_token_id),
-dtype=input_ids.dtype,
-),
--1,
-keepdims=False,
-)
+tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
 - 1
 )
+sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
 in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
 else:
 sequence_lengths = -1

View File

@@ -1082,16 +1082,10 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
 else:
 if input_ids is not None:
 sequence_lengths = (
-tf.reduce_sum(
-tf.cast(
-tf.math.not_equal(input_ids, self.config.pad_token_id),
-dtype=input_ids.dtype,
-),
--1,
-keepdims=False,
-)
+tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
 - 1
 )
+sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
 in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
 else:
 sequence_lengths = -1

View File

@@ -867,16 +867,10 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
 else:
 if input_ids is not None:
 sequence_lengths = (
-tf.reduce_sum(
-tf.cast(
-tf.math.not_equal(input_ids, self.config.pad_token_id),
-dtype=input_ids.dtype,
-),
--1,
-keepdims=False,
-)
+tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
 - 1
 )
+sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
 in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
 else:
 sequence_lengths = -1

View File

@@ -813,7 +813,9 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
 sequence_lengths = -1
 else:
 if input_ids is not None:
-sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+logits.device
+)
 else:
 sequence_lengths = -1
 logger.warning(

View File

@@ -809,16 +809,10 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
 else:
 if input_ids is not None:
 sequence_lengths = (
-tf.reduce_sum(
-tf.cast(
-tf.math.not_equal(input_ids, self.config.pad_token_id),
-dtype=input_ids.dtype,
-),
--1,
-keepdims=False,
-)
+tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
 - 1
 )
+sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
 in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
 else:
 sequence_lengths = -1

View File

@@ -1066,16 +1066,10 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
 else:
 if input_ids is not None:
 sequence_lengths = (
-tf.reduce_sum(
-tf.cast(
-tf.math.not_equal(input_ids, self.config.pad_token_id),
-dtype=input_ids.dtype,
-),
--1,
-keepdims=False,
-)
+tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
 - 1
 )
+sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
 in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
 else:
 sequence_lengths = -1

View File

@@ -1247,7 +1247,9 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
 sequence_lengths = -1
 else:
 if input_ids is not None:
-sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
+logits.device
+)
 else:
 sequence_lengths = -1
 logger.warning(