Stop confusing the TF compiler with ModelOutput objects (#28712)
parent a638de1987
commit 708b19eb09
@@ -1171,7 +1171,7 @@ class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):
             attention_mask=attention_mask,
             encoder_hidden_states=image_embeds,
             labels=labels,
-            return_dict=return_dict,
+            return_dict=False,
             training=training,
         )

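The hunk above asks the text decoder for a plain tuple, so everything tf.function traces is ordinary tensors rather than a ModelOutput dataclass. A minimal sketch of that pattern, with toy_decoder standing in for self.text_decoder (names and shapes are illustrative, not the library's API):

import tensorflow as tf

def toy_decoder(input_ids, labels=None, return_dict=False):
    # Stand-in for self.text_decoder: returns (loss, logits) when labels are
    # given, else (logits,). return_dict=False means "plain tuple", so the
    # traced graph never touches a ModelOutput dataclass.
    logits = tf.random.normal((tf.shape(input_ids)[0], 4))
    loss = tf.reduce_mean(logits) if labels is not None else None
    return tuple(t for t in (loss, logits) if t is not None)

@tf.function
def forward(input_ids, labels=None):
    outputs = toy_decoder(input_ids, labels=labels, return_dict=False)
    if labels is not None:
        loss, logits = outputs[0], outputs[1]
    else:
        loss, logits = None, outputs[0]
    return loss, logits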
@@ -1179,12 +1179,19 @@ class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):
             outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
             return tuple(output for output in outputs if output is not None)

-        if outputs.loss is not None and outputs.loss.shape.rank == 0:
-            outputs.loss = tf.reshape(outputs.loss, (1,))
+        if labels is not None:
+            loss = outputs[0]
+            logits = outputs[1]
+        else:
+            loss = None
+            logits = outputs[0]
+
+        if loss is not None and loss.shape.rank == 0:
+            loss = tf.reshape(loss, (1,))

         return TFBlipForConditionalGenerationModelOutput(
-            loss=outputs.loss,
-            logits=outputs.logits,
+            loss=loss,
+            logits=logits,
             image_embeds=image_embeds,
             last_hidden_state=vision_outputs.last_hidden_state,
             hidden_states=vision_outputs.hidden_states,
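This hunk also stops writing back into the output object (outputs.loss = ...): it derives local loss/logits from tuple positions and constructs the ModelOutput exactly once at the end. A hedged sketch of that packaging step, with a dict standing in for TFBlipForConditionalGenerationModelOutput:

import tensorflow as tf

def package_outputs(outputs, labels):
    # With return_dict=False the decoder's tuple starts with loss only when
    # labels were supplied, so index accordingly instead of reading attributes.
    if labels is not None:
        loss, logits = outputs[0], outputs[1]
    else:
        loss, logits = None, outputs[0]
    # Keras expects a batched loss, so promote a rank-0 scalar to shape (1,).
    if loss is not None and loss.shape.rank == 0:
        loss = tf.reshape(loss, (1,))
    return {"loss": loss, "logits": logits}  # stand-in for the ModelOutput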
@@ -1060,7 +1060,8 @@ class TFBlipTextLMHeadModel(TFBlipTextPreTrainedModel):
             labels = labels[:, 1:]
             labels = tf.reshape(labels, (-1,))
             # Keras won't give us label smoothing for sparse CE, so we de-sparsify things here
-            one_hot_labels = tf.one_hot(labels, depth=self.config.vocab_size, dtype=tf.float32)
+            # Use relu to clamp masked labels at 0 to avoid NaN (we will be zeroing those out later anyway)
+            one_hot_labels = tf.one_hot(tf.nn.relu(labels), depth=self.config.vocab_size, dtype=tf.float32)
             loss_fct = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1, reduction="none")
             masked_positions = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32)
             lm_loss = loss_fct(one_hot_labels, shifted_prediction_scores)
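The new comment explains the trick: tf.nn.relu clamps the -100 ignore-sentinel to index 0 before one-hot encoding, and the mask built from tf.not_equal(labels, -100) zeroes those positions out of the loss, so the clamped value never matters. A self-contained sketch of the full masked-loss recipe (vocab size and tensors invented for illustration):

import tensorflow as tf

vocab_size = 8
labels = tf.constant([3, -100, 5])            # -100 marks ignored positions
logits = tf.random.normal((3, vocab_size))

# Clamp the sentinel to a valid index before one-hot; the bogus target row is
# zeroed out by the mask below, so the clamped value never reaches the loss.
one_hot_labels = tf.one_hot(tf.nn.relu(labels), depth=vocab_size, dtype=tf.float32)
loss_fct = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, label_smoothing=0.1, reduction="none"
)
per_token_loss = loss_fct(one_hot_labels, logits)
mask = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32)
lm_loss = tf.reduce_sum(per_token_loss * mask) / tf.reduce_sum(mask)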