Modified TF train_step (#13678)

Allows models to be compiled without a loss, and to use their internal loss computations for training with fit().
Matt 2021-09-27 14:47:07 +01:00 committed by GitHub
parent e00bc7cd2f
commit 367c2ef53b
1 changed file with 108 additions and 2 deletions


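Before the diff itself, a minimal sketch of the workflow this change enables; the checkpoint name and the toy data below are illustrative, and the key point is that the labels travel inside the input dict under the "labels" key:

import numpy as np
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification

# Toy encoded batch; real inputs would come from a tokenizer.
data = {
    "input_ids": np.random.randint(0, 1000, size=(8, 16)),
    "attention_mask": np.ones((8, 16), dtype=np.int64),
    "labels": np.random.randint(0, 2, size=(8,)),
}
dataset = tf.data.Dataset.from_tensor_slices(data).batch(4)

model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
# No `loss` argument: the new compile() routes the model's internal loss to Keras.
model.compile(optimizer="adam")
model.fit(dataset, epochs=1)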
@@ -26,6 +26,7 @@ import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.saving import hdf5_format
from .configuration_utils import PretrainedConfig
@@ -54,6 +55,10 @@ TFModelInputType = Union[
]
def dummy_loss(y_true, y_pred):
    return tf.reduce_mean(y_pred)
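Why this works as a loss: the model's "loss" output head already holds computed per-sample loss values, so the "loss function" only has to average them and can ignore y_true entirely. An illustrative call with made-up values, using the dummy_loss defined above:

import tensorflow as tf

per_sample_losses = tf.constant([0.7, 1.2, 0.3])  # hypothetical values from a model's "loss" output
print(dummy_loss(y_true=None, y_pred=per_sample_losses))  # prints a scalar tensor, here ≈ 0.733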
class TFModelUtilsMixin:
    """
    A few utilities for :obj:`tf.keras.Model`, to be used as a mixin.
@@ -296,8 +301,7 @@ def booleans_processing(config, **kwargs):
    if (
        kwargs["output_attentions"] not in (None, config.output_attentions)
        or kwargs["output_hidden_states"] not in (None, config.output_hidden_states)
        or ("use_cache" in kwargs and kwargs["use_cache"] not in (None, config.use_cache))
    ):
        tf_logger.warning(
            "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model."
@@ -712,6 +716,108 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        else:
            raise NotImplementedError
    def compile(
        self,
        optimizer="rmsprop",
        loss="passthrough",
        metrics=None,
        loss_weights=None,
        weighted_metrics=None,
        run_eagerly=None,
        steps_per_execution=None,
        **kwargs
    ):
        """
        This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss
        function themselves.
        """
        if loss == "passthrough":
            logger.warning(
                "No loss specified in compile() - the model's internal loss computation will be used as the "
                "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
                "Please ensure your labels are passed as the 'labels' key of the input dict so that they are "
                "accessible to the model during the forward pass. To disable this behaviour, please pass a "
                "loss argument, or explicitly pass loss=None if you do not want your model to compute a loss."
            )
            loss = {"loss": dummy_loss}
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            loss_weights=loss_weights,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            **kwargs,
        )
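The {"loss": dummy_loss} dict relies on Keras matching per-output losses by output name: Transformers models return their computed loss under the "loss" output key, which is also why a bogus "loss_loss" metric shows up and is cleaned up in train_step and test_step below. Passing any explicit loss bypasses the passthrough path; a hedged sketch, reusing a model compiled as in the example above:

import tensorflow as tf

# An explicit loss overrides the "passthrough" default entirely.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
# Per the warning text above, loss=None disables loss computation instead.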
    def train_step(self, data):
        """
        A modification of Keras's default train_step that cleans up the printed metrics when we use a dummy loss.
        """
        # These are the only transformations `Model.fit` applies to user-input
        # data when a `tf.data.Dataset` is provided.
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # These next two lines differ from the base method - they avoid issues when the labels are in
        # the input dict (and loss is computed internally)
        if y is None and "labels" in x:
            y = x["labels"]  # Stops confusion with metric computations
        # Run forward pass.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
        # Run backwards pass.
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        # Collect metrics to return
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        # These next two lines are also not in the base method - they correct the displayed metrics
        # when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
        if "loss" in return_metrics and "loss_loss" in return_metrics:
            del return_metrics["loss_loss"]
        return return_metrics
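The label fallback above matters because of how Keras unpacks dataset elements: when each element is a single dict, there is no separate y, so the labels have to be recovered from x. A small demonstration of that unpacking behaviour (the element contents are illustrative):

from tensorflow.python.keras.engine import data_adapter

element = {"input_ids": [[101, 2023]], "labels": [0]}  # hypothetical dict-only element
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(element)
assert y is None and sample_weight is None and "labels" in x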
    def test_step(self, data):
        """
        A modification of Keras's default test_step that cleans up the printed metrics when we use a dummy loss.
        """
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        # These next two lines differ from the base method - they avoid issues when the labels are in
        # the input dict (and loss is computed internally)
        if y is None and "labels" in x:
            y = x["labels"]  # Stops confusion with metric computations
        y_pred = self(x, training=False)
        # Updates stateful loss metrics.
        self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        # Collect metrics to return
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        # These next two lines are also not in the base method - they correct the displayed metrics
        # when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
        if "loss" in return_metrics and "loss_loss" in return_metrics:
            del return_metrics["loss_loss"]
        return return_metrics
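Since test_step mirrors the same label handling, evaluation works on the same kind of dataset without changes; an illustrative call, assuming a model and dict-style dataset as sketched earlier:

# Returns a single cleaned-up "loss" entry instead of a duplicated "loss_loss".
metrics = model.evaluate(dataset, return_dict=True)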
    def set_input_embeddings(self, value):
        """
        Set model's input embeddings