Adding new train_step logic to make things less confusing for users (#15994)
* Adding new train_step logic to make things less confusing for users
* DO NOT ASK WHY WE NEED THAT SUBCLASS
* Metrics now working, at least for single-output models with type annotations!
* Updates and TODOs for the new train_step
* Make fixup
* Temporary test workaround until T5 has types
* Temporary test workaround until T5 has types
* I think this actually works! Needs a lot of tests though
* Make style/quality
* Revert changes to T5 tests
* Deleting the aforementioned unmentionable subclass
* Deleting the aforementioned unmentionable subclass
* Adding a Keras API test
* Style fixes
* Removing unneeded TODO and comments
* Update test_step too
* Stop trying to compute metrics with the dummy_loss, patch up test
* Make style
* make fixup
* Docstring cleanup
* make fixup
* make fixup
* Stop expanding 1D input tensors when using dummy loss
* Adjust T5 test given the new compile()
* make fixup
* Skipping test for convnext
* Removing old T5-specific Keras test now that we have a common one
* make fixup
* make fixup
* Only skip convnext test on CPU
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Avoiding TF import issues
* make fixup
* Update compile() to support TF 2.3
* Skipping model.fit() on template classes for now
* Skipping model.fit() on template class tests for now
* Replace ad-hoc solution with find_labels
* make fixup

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 7ccacdf10f
commit 4354005291
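In practice, the behaviour this commit lands looks roughly like the following. This is a hedged usage sketch, not part of the patch; the checkpoint name is real, but `train_dataset` is a placeholder:

```python
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")

# No `loss` argument: compile() logs a warning and falls back to the
# model's internal loss computation (the "dummy loss" path below).
model.compile(optimizer="adam")

# `train_dataset` is an assumed tf.data.Dataset whose elements are dicts that
# include a "labels" key, or (features, labels) tuples; both now work.
model.fit(train_dataset, epochs=1)
```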
@@ -38,7 +38,6 @@ from .activations_tf import get_tf_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation_tf_utils import TFGenerationMixin
-from .modeling_tf_outputs import TFSeq2SeqLMOutput
 from .tf_utils import shape_list
 from .tokenization_utils_base import BatchEncoding
 from .utils import (
@@ -53,6 +52,7 @@ from .utils import (
     RevisionNotFoundError,
     cached_path,
     copy_func,
+    find_labels,
     has_file,
     hf_bucket_url,
     is_offline_mode,
@@ -715,6 +715,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
     base_model_prefix = ""
     main_input_name = "input_ids"
     _auto_class = None
+    _using_dummy_loss = None

     # a list of re pattern of tensor names to ignore from the model when loading the model weights
     # (and avoid unnecessary warnings).
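The new `_using_dummy_loss` flag is tri-state: `None` until `compile()` runs, then `True` when the internal loss is used and `False` otherwise. A minimal sketch of the lifecycle, where `TFMyModel` is a hypothetical `TFPreTrainedModel` subclass:

```python
model = TFMyModel(config)                     # hypothetical subclass of TFPreTrainedModel
assert model._using_dummy_loss is None        # unset until compile() is called

model.compile(optimizer="adam")               # no loss given -> internal loss path
assert model._using_dummy_loss is True

model.compile(optimizer="adam", loss="mse")   # explicit loss -> normal Keras path
assert model._using_dummy_loss is False
```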
@@ -899,24 +900,46 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         function themselves.
         """
         if loss == "passthrough":
-            if metrics is not None:
-                raise ValueError(
-                    "Passing metrics as a dict is not supported when using the internal loss! "
-                    "Please either compile the model with a loss, or remove the metrics argument. "
-                    "Note that advanced metrics using the `KerasMetricCallback` can still be used with the internal "
-                    "loss."
-                )
             logger.warning(
                 "No loss specified in compile() - the model's internal loss computation will be used as the "
                 "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
-                "Please ensure your labels are passed as keys in the input dict so that they are "
-                "accessible to the model during the forward pass. To disable this behaviour, please pass a "
-                "loss argument, or explicitly pass loss=None if you do not want your model to compute a loss."
+                "To disable this behaviour, please pass a loss argument, or explicitly pass "
+                "`loss=None` if you do not want your model to compute a loss."
             )
+            loss = dummy_loss
+            self._using_dummy_loss = True
+        else:
+            self._using_dummy_loss = False
+        parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
+        if "steps_per_execution" in parent_args:
+            super().compile(
+                optimizer=optimizer,
+                loss=loss,
+                metrics=metrics,
+                loss_weights=loss_weights,
+                weighted_metrics=weighted_metrics,
+                run_eagerly=run_eagerly,
+                steps_per_execution=steps_per_execution,
+                **kwargs,
+            )
+        else:
+            super().compile(
+                optimizer=optimizer,
+                loss=loss,
+                metrics=metrics,
+                loss_weights=loss_weights,
+                weighted_metrics=weighted_metrics,
+                run_eagerly=run_eagerly,
+                experimental_steps_per_execution=steps_per_execution,
+                **kwargs,
+            )
-            loss = {"loss": dummy_loss}
-        super().compile(
-            optimizer=optimizer,
-            loss=loss,
-            metrics=metrics,
-            loss_weights=loss_weights,
-            weighted_metrics=weighted_metrics,
-            run_eagerly=run_eagerly,
-            steps_per_execution=steps_per_execution,
-            **kwargs,
-        )

     def compute_loss(self, *args, **kwargs):
         if hasattr(tf.keras.Model, "compute_loss"):
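Two implementation details above are worth spelling out. First, `dummy_loss` just reduces the model's own `loss` output head; the sketch below mirrors the intent described in the diff rather than being a verbatim copy. Second, the signature introspection keeps `compile()` working on TF 2.3, where the knob was still called `experimental_steps_per_execution`:

```python
import inspect

import tensorflow as tf


def dummy_loss(y_true, y_pred):
    # The model already computed its loss internally during the forward pass;
    # y_pred here is that loss output, so we only need to reduce it.
    return tf.reduce_mean(y_pred)


# `steps_per_execution` was added to Model.compile() in TF 2.4;
# on TF 2.3 the same argument carried an `experimental_` prefix.
parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
print("steps_per_execution" in parent_args)  # True on TF >= 2.4
```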
@@ -935,40 +958,54 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
     def train_step(self, data):
         """
         A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
-        a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`. In
-        this case, it expects the same `data` as the original function (i.e. `(inputs, labels)`).
+        a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`.

-        However, when the model is compiled without specifying the loss AND the expected label columns are passed as
-        part of the input dictionary, the loss is computed internally (inside the model class) and is used in the
-        backwards pass. In this case, `data` is a singleton tuple containing `(inputs,)`.
-
-        This is possible under the aforementioned circumstances because our overriden compile function can set an
-        additional loss function that reduces a `loss` output, and the model will output a `loss` component (notice the
-        name matching) containing the loss that was used to train the pre-trained model.
+        When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
+        loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
+        as keys in the input dictionary, or as normal Keras labels.
         """

         # These are the only transformations `Model.fit` applies to user-input
         # data when a `tf.data.Dataset` is provided.
-        data = data_adapter.expand_1d(data)
+        if not self._using_dummy_loss:
+            data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
-        # These next two lines differ from the base method - they avoid issues when the labels are in
-        # the input dict (and loss is computed internally)
-        if y is None and "labels" in x:
-            y = x["labels"]  # Stops confusion with metric computations
-        elif y is None and "input_ids" in x:
-            # Just make any kind of dummy array to make loss work
-            y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)
+
+        # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
+        # if those keys are not already present in the input dict
+        if self._using_dummy_loss and y is not None:
+            arg_names = list(dict(inspect.signature(self.call).parameters).keys())
+            label_kwargs = find_labels(self.__class__)
+            # If y is a tensor and the model only has one label-like input, map y to that input
+            if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
+                if isinstance(x, tf.Tensor):
+                    x = {arg_names[0]: x}
+                label_kwarg = next(iter(label_kwargs))
+                if label_kwarg not in x:
+                    x[label_kwarg] = y
+            # Otherwise, copy keys from y to x as long as they weren't already present in x
+            elif isinstance(y, dict):
+                if isinstance(x, tf.Tensor):
+                    x = {arg_names[0]: x}
+                for key, val in y.items():
+                    if key in arg_names and key not in x:
+                        x[key] = val

         # Run forward pass.
         with tf.GradientTape() as tape:
             y_pred = self(x, training=True)
-            loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+            if self._using_dummy_loss:
+                loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
+            else:
+                loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
         # Run backwards pass.
         self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
-        # When y_pred is a ModelOutput and y is a tf.Tensor the metrics update
-        # should be done only with the relevant ModelOutput param that is
-        # considered by the loss.
-        if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
-            y_pred = y_pred["logits"]
-        self.compiled_metrics.update_state(y, y_pred, sample_weight)
+
+        # When using the dummy_loss we know metrics are not present, so we can skip a lot of this
+        if self._using_dummy_loss:
+            self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
+        else:
+            self.compiled_metrics.update_state(y, y_pred, sample_weight)
         # Collect metrics to return
         return_metrics = {}
         for metric in self.metrics:
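With the new `train_step`, the same model can therefore be fitted with labels in either position, as the docstring above describes. A usage sketch, where `input_ids`, `attention_mask` and `labels` are placeholder tensors:

```python
# Internal-loss training: no loss passed to compile().
model.compile(optimizer="adam")

features = {"input_ids": input_ids, "attention_mask": attention_mask}

# Option 1: labels travel inside the input dict.
model.fit({**features, "labels": labels}, epochs=1)

# Option 2: labels passed the usual Keras way; train_step copies them into
# the matching model argument found via find_labels().
model.fit(features, labels, epochs=1)
```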
@@ -985,23 +1022,51 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu

     def test_step(self, data):
         """
-        A modification of Keras's default test_step that cleans up the printed metrics when we use a dummy loss.
+        A modification of Keras's default `test_step` that cleans up the printed metrics when we use a dummy loss. If a
+        user specifies a loss at model compile time, this function behaves as the original Keras `test_step`.
+
+        When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
+        loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
+        as keys in the input dictionary, or as normal Keras labels.
         """
-        data = data_adapter.expand_1d(data)
+        # These are the only transformations `Model.fit` applies to user-input
+        # data when a `tf.data.Dataset` is provided.
+        if not self._using_dummy_loss:
+            data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
-        # These next two lines differ from the base method - they avoid issues when the labels are in
-        # the input dict (and loss is computed internally)
-        if y is None and "labels" in x:
-            y = x["labels"]  # Stops confusion with metric computations
-        elif y is None and "input_ids" in x:
-            # Just make any kind of dummy array to make loss work
-            y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)
+
+        # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
+        # if those keys are not already present in the input dict
+        if self._using_dummy_loss and y is not None:
+            arg_names = list(dict(inspect.signature(self.call).parameters).keys())
+            label_kwargs = find_labels(self.__class__)
+            # If y is a tensor and the model only has one label-like input, map y to that input
+            if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
+                if isinstance(x, tf.Tensor):
+                    x = {arg_names[0]: x}
+                label_kwarg = next(iter(label_kwargs))
+                if label_kwarg not in x:
+                    x[label_kwarg] = y
+            # Otherwise, copy keys from y to x as long as they weren't already present in x
+            elif isinstance(y, dict):
+                if isinstance(x, tf.Tensor):
+                    x = {arg_names[0]: x}
+                for key, val in y.items():
+                    if key in arg_names and key not in x:
+                        x[key] = val
+
         # Run forward pass.
         y_pred = self(x, training=False)
-        self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
-        # Updates stateful loss metrics.
-        if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
-            y_pred = y_pred["logits"]
-        self.compiled_metrics.update_state(y, y_pred, sample_weight)
+        if self._using_dummy_loss:
+            self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
+        else:
+            self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+
+        # When using the dummy_loss we know metrics are not present, so we can skip a lot of this
+        if self._using_dummy_loss:
+            self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
+        else:
+            self.compiled_metrics.update_state(y, y_pred, sample_weight)
         # Collect metrics to return
         return_metrics = {}
         for metric in self.metrics:
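The label mapping in both steps leans on `find_labels`, which inspects a model class's signature for label-like arguments. A quick illustration; the commented outputs are the expected values assuming the usual signatures for these classes:

```python
from transformers import TFBertForQuestionAnswering, TFBertForSequenceClassification
from transformers.utils import find_labels

print(find_labels(TFBertForSequenceClassification))  # ["labels"]
print(find_labels(TFBertForQuestionAnswering))       # ["start_positions", "end_positions"]
```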
@@ -259,6 +259,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
             list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
         )

+
     def create_and_check_causal_lm_model_past(
         self,
         config,
@@ -597,6 +598,10 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

+    @unittest.skip(reason="Template classes interact badly with this test.")
+    def test_keras_fit(self):
+        pass
+
     def test_causal_lm_base_model(self):
         """Test the base model of the causal LM model
@@ -947,6 +952,10 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte
                 models_equal = False
             self.assertTrue(models_equal)

+    @unittest.skip(reason="Template classes interact badly with this test.")
+    def test_keras_fit(self):
+        pass
+

 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
@@ -143,6 +143,13 @@ class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_inputs_embeds(self):
         pass

+    @unittest.skipIf(
+        not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
+        reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
+    )
+    def test_keras_fit(self):
+        pass
+
     @unittest.skip(reason="ConvNext does not support input and output embeddings")
     def test_model_common_attributes(self):
         pass
@@ -804,33 +804,3 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
         translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

         self.assertEqual(translation, expected_translation)
-
-    def test_finetune_keras_trainer(self):
-        """Ensure that the model can be fine-tuned via the keras API and
-        that metrics work as expected.
-        """
-
-        # This metric expects to be called with the logits output
-        def _accuracy(y_true, y_pred):
-            return tf.keras.metrics.sparse_categorical_crossentropy(y_true[:, 0], y_pred[:, 0])
-
-        # measure the accuracy of the first token
-        class FirstTokenAccuracy(tf.keras.metrics.MeanMetricWrapper):
-            def __init__(self, name="accuracy", **kwargs):
-                super().__init__(_accuracy, name=name, **kwargs)
-
-        model = self.model
-        model.compile("adam", metrics=FirstTokenAccuracy())
-        tokenizer = T5Tokenizer.from_pretrained("t5-small")
-
-        examples = [
-            ("sentiment: Everything is awesome!", "positive"),
-            ("sentiment: Tensorflow datasets are hard to use", "negative"),
-        ]
-
-        inputs = dict(tokenizer([x[0] for x in examples], padding=True, return_tensors="tf"))
-        inputs["labels"] = tokenizer([x[1] for x in examples], return_tensors="tf").input_ids
-
-        model.fit(inputs)
-        m = model.evaluate(inputs)
-        self.assertEqual(len(m), 2)
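The removed T5-specific test is superseded by the common `test_keras_fit` added below. For metrics that need more than the loss head, the compile() messages point to `KerasMetricCallback`; a hedged sketch of that pattern, where the `metric_fn` body, `eval_dataset` and `train_dataset` are placeholders:

```python
from transformers.keras_callbacks import KerasMetricCallback


def metric_fn(eval_predictions):
    # Receives a (predictions, labels) tuple of numpy arrays and returns a
    # dict of metric names to values; here, a first-token accuracy like the
    # one the removed test measured.
    predictions, labels = eval_predictions
    return {"first_token_accuracy": float((predictions.argmax(-1)[:, 0] == labels[:, 0]).mean())}


callback = KerasMetricCallback(metric_fn=metric_fn, eval_dataset=eval_dataset)
model.fit(train_dataset, callbacks=[callback])
```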
@@ -1302,6 +1302,56 @@ class TFModelTesterMixin:

         self.assertEqual(loss.shape, [loss_size])

+    def test_keras_fit(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            if getattr(model, "hf_compute_loss", None):
+                # Test that model correctly compute the loss with kwargs
+                prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
+                # Is there a better way to remove these decoder inputs?
+                prepared_for_class = {
+                    key: val
+                    for key, val in prepared_for_class.items()
+                    if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
+                }
+
+                possible_label_cols = {
+                    "labels",
+                    "label",
+                    "label_ids",
+                    "start_positions",
+                    "start_position",
+                    "end_positions",
+                    "end_position",
+                    "next_sentence_label",
+                }
+                label_names = possible_label_cols.intersection(set(prepared_for_class))
+                self.assertGreater(len(label_names), 0, msg="No matching label names found!")
+                labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
+                inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
+                self.assertGreater(len(inputs_minus_labels), 0)
+                model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
+                # Make sure the model fits without crashing regardless of where we pass the labels
+                history1 = model.fit(
+                    prepared_for_class,
+                    validation_data=prepared_for_class,
+                    steps_per_epoch=1,
+                    validation_steps=1,
+                    shuffle=False,
+                )
+                val_loss1 = history1.history["val_loss"][0]
+                history2 = model.fit(
+                    inputs_minus_labels,
+                    labels,
+                    validation_data=(inputs_minus_labels, labels),
+                    steps_per_epoch=1,
+                    validation_steps=1,
+                    shuffle=False,
+                )
+                val_loss2 = history2.history["val_loss"][0]
+                self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
+
     def test_generate_with_headmasking(self):
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()