Adding new train_step logic to make things less confusing for users (#15994)

* Adding new train_step logic to make things less confusing for users

* DO NOT ASK WHY WE NEED THAT SUBCLASS

* Metrics now working, at least for single-output models with type annotations!

* Updates and TODOs for the new train_step

* Make fixup

* Temporary test workaround until T5 has types

* Temporary test workaround until T5 has types

* I think this actually works! Needs a lot of tests though

* Make style/quality

* Revert changes to T5 tests

* Deleting the aforementioned unmentionable subclass

* Deleting the aforementioned unmentionable subclass

* Adding a Keras API test

* Style fixes

* Removing unneeded TODO and comments

* Update test_step too

* Stop trying to compute metrics with the dummy_loss, patch up test

* Make style

* make fixup

* Docstring cleanup

* make fixup

* make fixup

* Stop expanding 1D input tensors when using dummy loss

* Adjust T5 test given the new compile()

* make fixup

* Skipping test for convnext

* Removing old T5-specific Keras test now that we have a common one

* make fixup

* make fixup

* Only skip convnext test on CPU

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Avoiding TF import issues

* make fixup

* Update compile() to support TF 2.3

* Skipping model.fit() on template classes for now

* Skipping model.fit() on template class tests for now

* Replace ad-hoc solution with find_labels

* make fixup

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Matt 2022-04-05 14:23:27 +01:00 committed by GitHub
parent 7ccacdf10f
commit 4354005291
5 changed files with 184 additions and 83 deletions

View File

@@ -38,7 +38,6 @@ from .activations_tf import get_tf_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_outputs import TFSeq2SeqLMOutput
from .tf_utils import shape_list
from .tokenization_utils_base import BatchEncoding
from .utils import (
@@ -53,6 +52,7 @@ from .utils import (
RevisionNotFoundError,
cached_path,
copy_func,
find_labels,
has_file,
hf_bucket_url,
is_offline_mode,
@@ -715,6 +715,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
base_model_prefix = ""
main_input_name = "input_ids"
_auto_class = None
_using_dummy_loss = None
# a list of re pattern of tensor names to ignore from the model when loading the model weights
# (and avoid unnecessary warnings).
@@ -899,24 +900,46 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
function themselves.
"""
if loss == "passthrough":
if metrics is not None:
raise ValueError(
"Passing metrics as a dict is not supported when using the internal loss! "
"Please either compile the model with a loss, or remove the metrics argument. "
"Note that advanced metrics using the `KerasMetricCallback` can still be used with the internal "
"loss."
)
logger.warning(
"No loss specified in compile() - the model's internal loss computation will be used as the "
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
"Please ensure your labels are passed as keys in the input dict so that they are "
"accessible to the model during the forward pass. To disable this behaviour, please pass a "
"loss argument, or explicitly pass loss=None if you do not want your model to compute a loss."
"To disable this behaviour, please pass a loss argument, or explicitly pass "
"`loss=None` if you do not want your model to compute a loss."
)
loss = dummy_loss
self._using_dummy_loss = True
else:
self._using_dummy_loss = False
parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
if "steps_per_execution" in parent_args:
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
weighted_metrics=weighted_metrics,
run_eagerly=run_eagerly,
steps_per_execution=steps_per_execution,
**kwargs,
)
else:
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
weighted_metrics=weighted_metrics,
run_eagerly=run_eagerly,
experimental_steps_per_execution=steps_per_execution,
**kwargs,
)
loss = {"loss": dummy_loss}
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
weighted_metrics=weighted_metrics,
run_eagerly=run_eagerly,
steps_per_execution=steps_per_execution,
**kwargs,
)
def compute_loss(self, *args, **kwargs):
if hasattr(tf.keras.Model, "compute_loss"):
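
For context, a rough sketch of how the new `compile()` behaviour looks from the user side (the checkpoint name and hyperparameters are illustrative, not part of this PR):

```python
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")

# No `loss` argument: the "passthrough" default applies, the warning above is logged,
# and the model's internal loss (its `loss` output head) is used for training.
model.compile(optimizer=tf.keras.optimizers.Adam(3e-5))

# Metrics are not supported together with the internal loss (the ValueError above),
# so they require an explicit Keras loss:
model.compile(
    optimizer=tf.keras.optimizers.Adam(3e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
```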
@@ -935,40 +958,54 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
def train_step(self, data):
"""
A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`. In
this case, it expects the same `data` as the original function (i.e. `(inputs, labels)`).
a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`.
However, when the model is compiled without specifying the loss AND the expected label columns are passed as
part of the input dictionary, the loss is computed internally (inside the model class) and is used in the
backwards pass. In this case, `data` is a singleton tuple containing `(inputs,)`.
This is possible under the aforementioned circumstances because our overridden compile function can set an
additional loss function that reduces a `loss` output, and the model will output a `loss` component (notice the
name matching) containing the loss that was used to train the pre-trained model.
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
as keys in the input dictionary, or as normal Keras labels.
"""
# These are the only transformations `Model.fit` applies to user-input
# data when a `tf.data.Dataset` is provided.
data = data_adapter.expand_1d(data)
if not self._using_dummy_loss:
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# These next two lines differ from the base method - they avoid issues when the labels are in
# the input dict (and loss is computed internally)
if y is None and "labels" in x:
y = x["labels"] # Stops confusion with metric computations
elif y is None and "input_ids" in x:
# Just make any kind of dummy array to make loss work
y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
# if those keys are not already present in the input dict
if self._using_dummy_loss and y is not None:
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
# If y is a tensor and the model only has one label-like input, map y to that input
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
if isinstance(x, tf.Tensor):
x = {arg_names[0]: x}
label_kwarg = next(iter(label_kwargs))
if label_kwarg not in x:
x[label_kwarg] = y
# Otherwise, copy keys from y to x as long as they weren't already present in x
elif isinstance(y, dict):
if isinstance(x, tf.Tensor):
x = {arg_names[0]: x}
for key, val in y.items():
if key in arg_names and key not in x:
x[key] = val
# Run forward pass.
with tf.GradientTape() as tape:
y_pred = self(x, training=True)
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
if self._using_dummy_loss:
loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
else:
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# Run backwards pass.
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
# When y_pred is a ModelOutput and y is a tf.Tensor the metrics update
# should be done only with the relevant ModelOutput param that is
# considered by the loss.
if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
y_pred = y_pred["logits"]
self.compiled_metrics.update_state(y, y_pred, sample_weight)
# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
if self._using_dummy_loss:
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
else:
self.compiled_metrics.update_state(y, y_pred, sample_weight)
# Collect metrics to return
return_metrics = {}
for metric in self.metrics:
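
A rough sketch of the label handling described in the new `train_step` (the checkpoint and data are illustrative; `find_labels` is the utility this PR switches to):

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
from transformers.utils import find_labels

model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# find_labels() inspects the model class signature and returns its label argument names,
# e.g. ["labels"] for sequence classification heads.
print(find_labels(type(model)))

model.compile(optimizer="adam")  # internal loss, as in the compile() change above

inputs = dict(tokenizer(["Great movie!", "Terrible movie."], padding=True, return_tensors="tf"))
labels = tf.constant([1, 0])

# Labels can live in the input dict...
model.fit(dict(inputs, labels=labels), epochs=1)

# ...or be passed as ordinary Keras labels; train_step copies them into the matching
# model argument returned by find_labels() before the forward pass.
model.fit(inputs, labels, epochs=1)
```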
@@ -985,23 +1022,51 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
def test_step(self, data):
"""
A modification of Keras's default test_step that cleans up the printed metrics when we use a dummy loss.
A modification of Keras's default `test_step` that cleans up the printed metrics when we use a dummy loss. If a
user specifies a loss at model compile time, this function behaves as the original Keras `test_step`.
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
as keys in the input dictionary, or as normal Keras labels.
"""
data = data_adapter.expand_1d(data)
# These are the only transformations `Model.fit` applies to user-input
# data when a `tf.data.Dataset` is provided.
if not self._using_dummy_loss:
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# These next two lines differ from the base method - they avoid issues when the labels are in
# the input dict (and loss is computed internally)
if y is None and "labels" in x:
y = x["labels"] # Stops confusion with metric computations
elif y is None and "input_ids" in x:
# Just make any kind of dummy array to make loss work
y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)
# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
# if those keys are not already present in the input dict
if self._using_dummy_loss and y is not None:
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
# If y is a tensor and the model only has one label-like input, map y to that input
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
if isinstance(x, tf.Tensor):
x = {arg_names[0]: x}
label_kwarg = next(iter(label_kwargs))
if label_kwarg not in x:
x[label_kwarg] = y
# Otherwise, copy keys from y to x as long as they weren't already present in x
elif isinstance(y, dict):
if isinstance(x, tf.Tensor):
x = {arg_names[0]: x}
for key, val in y.items():
if key in arg_names and key not in x:
x[key] = val
# Run forward pass.
y_pred = self(x, training=False)
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# Updates stateful loss metrics.
if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
y_pred = y_pred["logits"]
self.compiled_metrics.update_state(y, y_pred, sample_weight)
if self._using_dummy_loss:
self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
else:
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
if self._using_dummy_loss:
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
else:
self.compiled_metrics.update_state(y, y_pred, sample_weight)
# Collect metrics to return
return_metrics = {}
for metric in self.metrics:
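
Since `test_step` mirrors `train_step`, the same conventions carry over to evaluation, including `tf.data.Dataset` inputs whose elements are dicts containing the label keys. A sketch, again with an illustrative checkpoint:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model.compile(optimizer="adam")  # internal loss

encodings = tokenizer(["Great movie!", "Terrible movie."], padding=True, return_tensors="np")
features = dict(encodings, labels=[1, 0])

# Elements of the dataset are dicts; the "labels" key is picked up by the internal loss
# in both fit() and evaluate(), since test_step applies the same label handling.
dataset = tf.data.Dataset.from_tensor_slices(features).batch(2)
model.fit(dataset, epochs=1)
model.evaluate(dataset)
```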

View File

@@ -259,6 +259,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_causal_lm_model_past(
self,
config,
@@ -597,6 +598,10 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Template classes interact badly with this test.")
def test_keras_fit(self):
pass
def test_causal_lm_base_model(self):
"""Test the base model of the causal LM model
@@ -947,6 +952,10 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte
models_equal = False
self.assertTrue(models_equal)
@unittest.skip(reason="Template classes interact badly with this test.")
def test_keras_fit(self):
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""

View File

@@ -143,6 +143,13 @@ class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
)
def test_keras_fit(self):
pass
@unittest.skip(reason="ConvNext does not support input and output embeddings")
def test_model_common_attributes(self):
pass

View File

@@ -804,33 +804,3 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertEqual(translation, expected_translation)
def test_finetune_keras_trainer(self):
"""Ensure that the model can be fine-tuned via the keras API and
that metrics work as expected.
"""
# This metric expects to be called with the logits output
def _accuracy(y_true, y_pred):
return tf.keras.metrics.sparse_categorical_crossentropy(y_true[:, 0], y_pred[:, 0])
# measure the accuracy of the first token
class FirstTokenAccuracy(tf.keras.metrics.MeanMetricWrapper):
def __init__(self, name="accuracy", **kwargs):
super().__init__(_accuracy, name=name, **kwargs)
model = self.model
model.compile("adam", metrics=FirstTokenAccuracy())
tokenizer = T5Tokenizer.from_pretrained("t5-small")
examples = [
("sentiment: Everything is awesome!", "positive"),
("sentiment: Tensorflow datasets are hard to use", "negative"),
]
inputs = dict(tokenizer([x[0] for x in examples], padding=True, return_tensors="tf"))
inputs["labels"] = tokenizer([x[1] for x in examples], return_tensors="tf").input_ids
model.fit(inputs)
m = model.evaluate(inputs)
self.assertEqual(len(m), 2)
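
The removed T5-specific test above is now covered by the common `test_keras_fit` added below. For reference, roughly the same workflow under the new `compile()`/`train_step` would look like this (a sketch based on the deleted test; the custom metric is dropped because metrics are not supported with the internal loss):

```python
import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

examples = [
    ("sentiment: Everything is awesome!", "positive"),
    ("sentiment: Tensorflow datasets are hard to use", "negative"),
]
inputs = dict(tokenizer([x[0] for x in examples], padding=True, return_tensors="tf"))
inputs["labels"] = tokenizer([x[1] for x in examples], return_tensors="tf").input_ids

# Compile without a loss: the internal seq2seq loss is used, and the "labels" key in
# the input dict is picked up by the new train_step / test_step.
model.compile(optimizer=tf.keras.optimizers.Adam(3e-5))
model.fit(inputs)
model.evaluate(inputs)
```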

View File

@@ -1302,6 +1302,56 @@ class TFModelTesterMixin:
self.assertEqual(loss.shape, [loss_size])
def test_keras_fit(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
if getattr(model, "hf_compute_loss", None):
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Is there a better way to remove these decoder inputs?
prepared_for_class = {
key: val
for key, val in prepared_for_class.items()
if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
}
possible_label_cols = {
"labels",
"label",
"label_ids",
"start_positions",
"start_position",
"end_positions",
"end_position",
"next_sentence_label",
}
label_names = possible_label_cols.intersection(set(prepared_for_class))
self.assertGreater(len(label_names), 0, msg="No matching label names found!")
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
self.assertGreater(len(inputs_minus_labels), 0)
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
prepared_for_class,
validation_data=prepared_for_class,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss1 = history1.history["val_loss"][0]
history2 = model.fit(
inputs_minus_labels,
labels,
validation_data=(inputs_minus_labels, labels),
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
def test_generate_with_headmasking(self):
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()