Stop confusing the TF compiler with ModelOutput objects (#28712)

* Stop confusing the TF compiler with ModelOutput objects

* Stop confusing the TF compiler with ModelOutput objects
This commit is contained in:
Matt 2024-01-26 12:22:29 +00:00 committed by GitHub
parent a638de1987
commit 708b19eb09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 6 deletions

View File

@ -1171,7 +1171,7 @@ class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):
attention_mask=attention_mask,
encoder_hidden_states=image_embeds,
labels=labels,
return_dict=return_dict,
return_dict=False,
training=training,
)
@ -1179,12 +1179,19 @@ class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):
outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
return tuple(output for output in outputs if output is not None)
if outputs.loss is not None and outputs.loss.shape.rank == 0:
outputs.loss = tf.reshape(outputs.loss, (1,))
if labels is not None:
loss = outputs[0]
logits = outputs[1]
else:
loss = None
logits = outputs[0]
if loss is not None and loss.shape.rank == 0:
loss = tf.reshape(loss, (1,))
return TFBlipForConditionalGenerationModelOutput(
loss=outputs.loss,
logits=outputs.logits,
loss=loss,
logits=logits,
image_embeds=image_embeds,
last_hidden_state=vision_outputs.last_hidden_state,
hidden_states=vision_outputs.hidden_states,

View File

@ -1060,7 +1060,8 @@ class TFBlipTextLMHeadModel(TFBlipTextPreTrainedModel):
labels = labels[:, 1:]
labels = tf.reshape(labels, (-1,))
# Keras won't give us label smoothing for sparse CE, so we de-sparsify things here
one_hot_labels = tf.one_hot(labels, depth=self.config.vocab_size, dtype=tf.float32)
# Use relu to clamp masked labels at 0 to avoid NaN (we will be zeroing those out later anyway)
one_hot_labels = tf.one_hot(tf.nn.relu(labels), depth=self.config.vocab_size, dtype=tf.float32)
loss_fct = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1, reduction="none")
masked_positions = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32)
lm_loss = loss_fct(one_hot_labels, shifted_prediction_scores)