Modified TF train_step (#13678)
Allows models to be compiled without a loss, and to use the internal loss computations for training with fit()
This commit is contained in:
parent
e00bc7cd2f
commit
367c2ef53b
|
@ -26,6 +26,7 @@ import h5py
|
|||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine import data_adapter
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
@ -54,6 +55,10 @@ TFModelInputType = Union[
|
|||
]
|
||||
|
||||
|
||||
def dummy_loss(y_true, y_pred):
|
||||
return tf.reduce_mean(y_pred)
|
||||
|
||||
|
||||
class TFModelUtilsMixin:
|
||||
"""
|
||||
A few utilities for :obj:`tf.keras.Model`, to be used as a mixin.
|
||||
|
@ -296,8 +301,7 @@ def booleans_processing(config, **kwargs):
|
|||
if (
|
||||
kwargs["output_attentions"] not in (None, config.output_attentions)
|
||||
or kwargs["output_hidden_states"] not in (None, config.output_hidden_states)
|
||||
or "use_cache" in kwargs
|
||||
and kwargs["use_cache"] not in (None, config.use_cache)
|
||||
or ("use_cache" in kwargs and kwargs["use_cache"] not in (None, config.use_cache))
|
||||
):
|
||||
tf_logger.warning(
|
||||
"The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model."
|
||||
|
@ -712,6 +716,108 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def compile(
|
||||
self,
|
||||
optimizer="rmsprop",
|
||||
loss="passthrough",
|
||||
metrics=None,
|
||||
loss_weights=None,
|
||||
weighted_metrics=None,
|
||||
run_eagerly=None,
|
||||
steps_per_execution=None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss
|
||||
function themselves.
|
||||
"""
|
||||
if loss == "passthrough":
|
||||
logger.warning(
|
||||
"No loss specified in compile() - the model's internal loss computation will be used as the "
|
||||
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
|
||||
"Please ensure your labels are passed as the 'labels' key of the input dict so that they are "
|
||||
"accessible to the model during the forward pass. To disable this behaviour, please pass a "
|
||||
"loss argument, or explicitly pass loss=None if you do not want your model to compute a loss."
|
||||
)
|
||||
loss = {"loss": dummy_loss}
|
||||
super().compile(
|
||||
optimizer=optimizer,
|
||||
loss=loss,
|
||||
metrics=metrics,
|
||||
loss_weights=loss_weights,
|
||||
weighted_metrics=weighted_metrics,
|
||||
run_eagerly=run_eagerly,
|
||||
steps_per_execution=steps_per_execution,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def train_step(self, data):
|
||||
"""
|
||||
A modification of Keras's default train_step that cleans up the printed metrics when we use a dummy loss.
|
||||
"""
|
||||
# These are the only transformations `Model.fit` applies to user-input
|
||||
# data when a `tf.data.Dataset` is provided.
|
||||
data = data_adapter.expand_1d(data)
|
||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
||||
# These next two lines differ from the base method - they avoid issues when the labels are in
|
||||
# the input dict (and loss is computed internally)
|
||||
if y is None and "labels" in x:
|
||||
y = x["labels"] # Stops confusion with metric computations
|
||||
# Run forward pass.
|
||||
with tf.GradientTape() as tape:
|
||||
y_pred = self(x, training=True)
|
||||
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||
# Run backwards pass.
|
||||
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
|
||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||
# Collect metrics to return
|
||||
return_metrics = {}
|
||||
for metric in self.metrics:
|
||||
result = metric.result()
|
||||
if isinstance(result, dict):
|
||||
return_metrics.update(result)
|
||||
else:
|
||||
return_metrics[metric.name] = result
|
||||
# These next two lines are also not in the base method - they correct the displayed metrics
|
||||
# when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
|
||||
if "loss" in return_metrics and "loss_loss" in return_metrics:
|
||||
del return_metrics["loss_loss"]
|
||||
return return_metrics
|
||||
|
||||
def test_step(self, data):
|
||||
"""
|
||||
A modification of Keras's default test_step that cleans up the printed metrics when we use a dummy loss.
|
||||
"""
|
||||
data = data_adapter.expand_1d(data)
|
||||
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
|
||||
# These next two lines differ from the base method - they avoid issues when the labels are in
|
||||
# the input dict (and loss is computed internally)
|
||||
if y is None and "labels" in x:
|
||||
y = x["labels"] # Stops confusion with metric computations
|
||||
y_pred = self(x, training=False)
|
||||
if not self.loss:
|
||||
self.loss_tracker.update_state(y_pred.loss)
|
||||
return_metrics = {"loss": self.loss_tracker.result()}
|
||||
else:
|
||||
# Run anyway to update state
|
||||
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||
return_metrics = {}
|
||||
# Updates stateful loss metrics.
|
||||
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
|
||||
self.compiled_metrics.update_state(y, y_pred, sample_weight)
|
||||
# Collect metrics to return
|
||||
for metric in self.metrics:
|
||||
result = metric.result()
|
||||
if isinstance(result, dict):
|
||||
return_metrics.update(result)
|
||||
else:
|
||||
return_metrics[metric.name] = result
|
||||
# These next two lines are also not in the base method - they correct the displayed metrics
|
||||
# when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
|
||||
if "loss" in return_metrics and "loss_loss" in return_metrics:
|
||||
del return_metrics["loss_loss"]
|
||||
return return_metrics
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
"""
|
||||
Set model's input embeddings
|
||||
|
|
Loading…
Reference in New Issue