[`TF`] Also apply patch to support left padding (#25085)
* tf versions * apply changes to other models * 3 models slipped through the cracks
This commit is contained in:
parent
f104522718
commit
2fac342238
|
@ -785,7 +785,9 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
|
|||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
||||
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
|
||||
logits.device
|
||||
)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
|
|
|
@ -798,16 +798,10 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
|
|||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (
|
||||
tf.reduce_sum(
|
||||
tf.cast(
|
||||
tf.math.not_equal(input_ids, self.config.pad_token_id),
|
||||
dtype=input_ids.dtype,
|
||||
),
|
||||
-1,
|
||||
keepdims=False,
|
||||
)
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
|
|
@ -1082,16 +1082,10 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
|||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (
|
||||
tf.reduce_sum(
|
||||
tf.cast(
|
||||
tf.math.not_equal(input_ids, self.config.pad_token_id),
|
||||
dtype=input_ids.dtype,
|
||||
),
|
||||
-1,
|
||||
keepdims=False,
|
||||
)
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
|
|
@ -867,16 +867,10 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
|
|||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (
|
||||
tf.reduce_sum(
|
||||
tf.cast(
|
||||
tf.math.not_equal(input_ids, self.config.pad_token_id),
|
||||
dtype=input_ids.dtype,
|
||||
),
|
||||
-1,
|
||||
keepdims=False,
|
||||
)
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
|
|
@ -813,7 +813,9 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
|
|||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
||||
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
|
||||
logits.device
|
||||
)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
|
|
|
@ -809,16 +809,10 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
|
|||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (
|
||||
tf.reduce_sum(
|
||||
tf.cast(
|
||||
tf.math.not_equal(input_ids, self.config.pad_token_id),
|
||||
dtype=input_ids.dtype,
|
||||
),
|
||||
-1,
|
||||
keepdims=False,
|
||||
)
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
|
|
@ -1066,16 +1066,10 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
|
|||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (
|
||||
tf.reduce_sum(
|
||||
tf.cast(
|
||||
tf.math.not_equal(input_ids, self.config.pad_token_id),
|
||||
dtype=input_ids.dtype,
|
||||
),
|
||||
-1,
|
||||
keepdims=False,
|
||||
)
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
|
|
@ -1247,7 +1247,9 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
|
|||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
||||
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
|
||||
logits.device
|
||||
)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
|
|
Loading…
Reference in New Issue