Keras metric callback (#14867)
* Working on splitting out labels
* First working version
* Fixed concatenation of outputs and labels
* val_dataset -> eval_dataset
* Only pass input arrays in tokenizer.model_input_names
* Only pass input arrays in tokenizer.model_input_names
* Only remove unexpected keys when predict_with_generate is True
* Adding proper docstring
* Adding example to docstring
* Add a proper ROUGE metric example
* Add a proper ROUGE metric example
* Add version checking
* Update src/transformers/keras_callbacks.py
* Update src/transformers/keras_callbacks.py
* Update src/transformers/keras_callbacks.py
* Update src/transformers/keras_callbacks.py
* Remove requirement for tokenizer with predict_with_generate

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent fa39ff9fc4
commit b0c7d2ec58
@@ -2,8 +2,11 @@ import logging
 import os
 from pathlib import Path
 from time import sleep
-from typing import Optional, Union
+from typing import Callable, List, Optional, Union
+
+import numpy as np
 import tensorflow as tf
+from packaging.version import parse
 from tensorflow.keras.callbacks import Callback
 
 from huggingface_hub import Repository
@@ -16,6 +19,184 @@ from .modelcard import TrainingSummary
 logger = logging.getLogger(__name__)
 
 
+class KerasMetricCallback(Callback):
+    """
+    Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
+    compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
+    operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the
+    `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
+    metrics and return a dict mapping metric names to metric values.
+
+    We provide an example of a suitable `metric_fn` that computes ROUGE scores for a summarization model below.
+    Note that this example skips some post-processing for readability and simplicity, and should probably not be
+    used as-is!
+
+    ```py
+    from datasets import load_metric
+
+    rouge_metric = load_metric("rouge")
+
+
+    def rouge_fn(predictions, labels):
+        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+        result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
+        return {key: value.mid.fmeasure * 100 for key, value in result.items()}
+    ```
+
+    The above function will return a dict containing values which will be logged like any other Keras metric:
+
+    ```
+    {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781}
+    ```
+
+    Args:
+        metric_fn (`Callable`):
+            Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`.
+            These contain the model's outputs and matching labels from the dataset. It should return a dict mapping
+            metric names to numerical values.
+        eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
+            Validation data to be used to generate predictions for the `metric_fn`.
+        metric_fn_kwargs (`dict`, *optional*):
+            Additional keyword arguments to be passed to the `metric_fn`.
+        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
+            Tokenizer used to validate column names to be passed to the `generate()` function.
+        output_cols (`List[str]`, *optional*):
+            A list of columns to be retained from the model output as the predictions. Defaults to all.
+        label_cols (`List[str]`, *optional*):
+            A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is
+            not supplied.
+        batch_size (`int`, *optional*):
+            Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
+        predict_with_generate (`bool`, *optional*, defaults to `False`):
+            Whether we should use `model.generate()` to get outputs for the model.
+    """
+
+    def __init__(
+        self,
+        metric_fn: Callable,
+        eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        metric_fn_kwargs: Optional[dict] = None,
+        output_cols: Optional[List[str]] = None,
+        label_cols: Optional[List[str]] = None,
+        batch_size: Optional[int] = None,
+        predict_with_generate: bool = False,
+    ):
+        super().__init__()
+        self.metric_fn = metric_fn
+        self.batch_size = batch_size
+        if not isinstance(eval_dataset, tf.data.Dataset):
+            if batch_size is None:
+                raise ValueError(
+                    "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
+                    "the batch_size argument must be set."
+                )
+            # Wrap a tf.data.Dataset around it
+            eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False)
+        self.eval_dataset = eval_dataset
+        self.predict_with_generate = predict_with_generate
+        self.output_cols = output_cols
+        self.metric_fn_kwargs = metric_fn_kwargs or dict()
+        if tokenizer is not None:
+            self.model_input_names = tokenizer.model_input_names
+        else:
+            self.model_input_names = ["input_ids"]
+
+        # This next block attempts to parse out which elements of the dataset should be appended to the labels list
+        # that is passed to the metric_fn
+        if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2:
+            input_spec, label_spec = eval_dataset.element_spec
+        else:
+            input_spec = eval_dataset.element_spec
+            label_spec = None
+        if label_cols is not None:
+            for label in label_cols:
+                if label not in input_spec:
+                    raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!")
+            self.label_cols = label_cols
+            self.use_keras_label = False
+        elif label_spec is not None:
+            # If the dataset inputs are split into a 2-tuple of inputs and labels,
+            # assume the second element is the labels
+            self.label_cols = None
+            self.use_keras_label = True
+        elif "labels" in input_spec:
+            self.label_cols = ["labels"]
+            self.use_keras_label = False
+            logger.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
+        else:
+            raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
+        if parse(tf.__version__) < parse("2.7"):
+            logger.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
+
+    @staticmethod
+    def _concatenate_batches(batches):
+        # Flattens Numpy array batches into a list of single samples, where each sample is still np.ndarray
+        return [sample for batch in batches for sample in batch]
+
+    def _postprocess_predictions_or_labels(self, inputs):
+        # Merge the list of per-batch outputs into a single structure of per-sample arrays, mirroring
+        # whichever container type (dict, list/tuple or bare array) the individual batches used
+        if isinstance(inputs[0], dict):
+            outputs = dict()
+            for key in inputs[0].keys():
+                outputs[key] = self._concatenate_batches(batch[key] for batch in inputs)
+        elif isinstance(inputs[0], (list, tuple)):
+            outputs = []
+            for input_list in zip(*inputs):
+                outputs.append(self._concatenate_batches(input_list))
+        elif isinstance(inputs[0], np.ndarray):
+            outputs = self._concatenate_batches(inputs)
+        else:
+            raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
+        return outputs
+
+    def on_epoch_end(self, epoch, logs=None):
+        prediction_list = []
+        label_list = []
+
+        # The whole predict/generate loop is handled inside this method
+        for batch in self.eval_dataset:
+            if isinstance(batch, tuple):
+                batch, labels = batch
+            else:
+                labels = None
+            if self.predict_with_generate:
+                if isinstance(batch, dict):
+                    # generate() gets stressed out by any unexpected keys
+                    batch = {key: array for key, array in batch.items() if key in self.model_input_names}
+                predictions = self.model.generate(batch)
+            else:
+                predictions = self.model.predict(batch)
+                predictions = dict(predictions)
+            if self.output_cols is not None:
+                predictions = {key: predictions[key] for key in self.output_cols}
+            prediction_list.append(predictions)
+            if not self.use_keras_label:
+                labels = {key: batch[key].numpy() for key in self.label_cols}
+            elif isinstance(labels, dict):
+                labels = {key: array.numpy() for key, array in labels.items()}
+            elif isinstance(labels, (list, tuple)):
+                labels = [array.numpy() for array in labels]
+            elif isinstance(labels, tf.Tensor):
+                labels = labels.numpy()
+            else:
+                raise TypeError(f"Confused by labels of type {type(labels)}")
+            label_list.append(labels)
+
+        prediction_list = self._postprocess_predictions_or_labels(prediction_list)
+        label_list = self._postprocess_predictions_or_labels(label_list)
+
+        metric_output = self.metric_fn(prediction_list, label_list, **self.metric_fn_kwargs)
+        if not isinstance(metric_output, dict):
+            raise TypeError(
+                f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
+            )
+        # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch
+        # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of
+        # new keys in there, which will then get read by the History callback and treated like any other metric value.
+        # I promise that I have it in writing from Chollet that this is okay.
+        logs.update(metric_output)
+
+
 class PushToHubCallback(Callback):
     def __init__(
         self,
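For readers who want to see the new callback in action, here is a minimal usage sketch. It is not part of this commit: the T5 checkpoint, the toy dataset, and the bare `compile(optimizer=...)` call (which assumes a transformers version where the model's internal loss is used when no loss is passed) are all illustrative assumptions; `rouge_fn` mirrors the docstring example.

```py
import tensorflow as tf
from datasets import load_metric
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
from transformers.keras_callbacks import KerasMetricCallback

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # assumption: any seq2seq checkpoint works
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")
rouge_metric = load_metric("rouge")


def rouge_fn(predictions, labels):
    # Decode generated token IDs and reference label IDs back to strings before scoring
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
    return {key: value.mid.fmeasure * 100 for key, value in result.items()}


# A toy evaluation set: model inputs plus a "labels" key, so the callback's
# autodetection branch ('"labels" in input_spec') finds the labels on its own.
inputs = tokenizer(["summarize: The quick brown fox jumps over the lazy dog."], return_tensors="np")
targets = tokenizer(["A fox jumps over a dog."], return_tensors="np")
eval_dataset = tf.data.Dataset.from_tensor_slices(
    {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": targets["input_ids"],
    }
).batch(1)

metric_callback = KerasMetricCallback(
    metric_fn=rouge_fn,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    predict_with_generate=True,
)

# The callback computes ROUGE at the end of each epoch and writes the scores into
# the Keras logs dict, so they show up alongside the loss in fit() output.
model.compile(optimizer="adam")
model.fit(eval_dataset, epochs=1, callbacks=[metric_callback])  # toy data reused for brevity
```

With `predict_with_generate=True`, the callback calls `model.generate()` per batch and forwards only the keys listed in `tokenizer.model_input_names`, which is why passing a tokenizer (or the default `["input_ids"]`) matters here.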
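It may also help to see what the concatenation helpers do. `_concatenate_batches` flattens a list of batched arrays into a flat list of per-sample arrays rather than concatenating them, which is what lets batches with different sequence lengths coexist. A quick sketch of that behavior with toy data (not from this commit):

```py
import numpy as np

# Two batches of shape (2, seq_len) with different seq_len become four 1-D sample arrays
batches = [np.zeros((2, 5)), np.ones((2, 7))]
samples = [sample for batch in batches for sample in batch]
assert len(samples) == 4 and samples[0].shape == (5,) and samples[-1].shape == (7,)
```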