[finetune_trainer] enhancements and fixes (#9042)

* trainer and finetune_trainer enhancements and fixes

* add fallback default

* move the fixing of incorrect keys back into finetune trainer

* s/eval/val/ to match the split

* trainer can now use a different prefix than eval_ for metrics

* document new arg

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* use 'eval' as the default for metric_key_prefix

* complete adjust var names + disambiguate

* fix logger

* add clarifying comment

* add clarifying comment

* style

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/trainer.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* complete removal of optional for metric_key_prefix

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Stas Bekman 2020-12-14 17:45:33 -08:00 committed by GitHub
parent 251eb70c97
commit c19d04623e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 97 additions and 29 deletions

View File

@ -16,6 +16,7 @@
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from typing import Optional
@ -119,6 +120,46 @@ class DataTrainingArguments:
)
def speed_metrics(split, start_time, num_samples):
"""
Measure and return speed performance metrics.
This function requires a time snapshot `start_time` before the operation to be measured starts and this
function should be run immediately after the operation to be measured has completed.
Args:
- split: one of train, val, test
- start_time: operation start time
- num_samples: number of samples processed
"""
runtime = time.time() - start_time
result = {}
samples_per_second = 1 / (runtime / num_samples)
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
result[f"{split}_runtime"] = round(runtime, 4)
result[f"{split}_n_ojbs"] = num_samples
return result
def handle_metrics(split, metrics, output_dir):
"""
Log and save metrics
Args:
- split: one of train, val, test
- metrics: metrics dict
- output_dir: where to save the metrics
"""
logger.info(f"***** {split} metrics *****")
for key, value in metrics.items():
logger.info(f" {key} = {value}")
save_json(metrics, os.path.join(output_dir, f"{split}_results.json"))
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
@ -265,45 +306,56 @@ def main():
data_args=data_args,
)
all_metrics = {}
# Training
if training_args.do_train:
logger.info("*** Train ***")
start_time = time.time()
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
)
trainer.save_model()
# For convenience, we also re-save the tokenizer to the same directory,
# so that you can share your model easily on huggingface.co/models =)
metrics = speed_metrics("train", start_time, data_args.n_train)
trainer.save_model() # this also saves the tokenizer
if trainer.is_world_process_zero():
handle_metrics("train", metrics, training_args.output_dir)
all_metrics.update(metrics)
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
# For convenience, we also re-save the tokenizer to the same directory,
# so that you can share your model easily on huggingface.co/models =)
tokenizer.save_pretrained(training_args.output_dir)
# Evaluation
eval_results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
result = trainer.evaluate()
start_time = time.time()
metrics = trainer.evaluate(metric_key_prefix="val")
metrics.update(speed_metrics("val", start_time, data_args.n_val))
metrics["val_loss"] = round(metrics["val_loss"], 4)
if trainer.is_world_process_zero():
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
eval_results.update(result)
handle_metrics("val", metrics, training_args.output_dir)
all_metrics.update(metrics)
if training_args.do_predict:
logging.info("*** Test ***")
logger.info("*** Predict ***")
test_output = trainer.predict(test_dataset=test_dataset)
test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}
start_time = time.time()
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
metrics = test_output.metrics
metrics.update(speed_metrics("test", start_time, data_args.n_test))
if trainer.is_world_process_zero():
logger.info("***** Test results *****")
for key, value in test_metrics.items():
logger.info(" %s = %s", key, value)
save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
eval_results.update(test_metrics)
metrics["test_loss"] = round(metrics["test_loss"], 4)
handle_metrics("test", metrics, training_args.output_dir)
all_metrics.update(metrics)
if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode(
@ -313,8 +365,9 @@ def main():
write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))
if trainer.is_world_process_zero():
save_json(eval_results, "all_results.json")
return eval_results
save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json"))
return all_metrics
def _mp_fn(index):

View File

@ -462,7 +462,7 @@ def save_git_info(folder_path: str) -> None:
def save_json(content, path, indent=4, **json_dump_kwargs):
with open(path, "w") as f:
json.dump(content, f, indent=indent, **json_dump_kwargs)
json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)
def load_json(path):

View File

@ -1243,7 +1243,10 @@ class Trainer:
shutil.rmtree(checkpoint)
def evaluate(
self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
self,
eval_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
"""
Run evaluation and returns metrics.
@ -1261,6 +1264,9 @@ class Trainer:
ignore_keys (:obj:`Lst[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is "eval" (default)
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
@ -1278,6 +1284,7 @@ class Trainer:
# self.args.prediction_loss_only
prediction_loss_only=True if self.compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
self.log(output.metrics)
@ -1289,7 +1296,9 @@ class Trainer:
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
return output.metrics
def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
def predict(
self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval"
) -> PredictionOutput:
"""
Run prediction and returns predictions and potential metrics.
@ -1303,6 +1312,9 @@ class Trainer:
ignore_keys (:obj:`Lst[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is "eval" (default)
.. note::
@ -1322,7 +1334,9 @@ class Trainer:
test_dataloader = self.get_test_dataloader(test_dataset)
return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
return self.prediction_loop(
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
def prediction_loop(
self,
@ -1330,6 +1344,7 @@ class Trainer:
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> PredictionOutput:
"""
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
@ -1421,12 +1436,12 @@ class Trainer:
metrics = {}
if eval_loss is not None:
metrics["eval_loss"] = eval_loss.mean().item()
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
# Prefix all keys with eval_
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith("eval_"):
metrics[f"eval_{key}"] = metrics.pop(key)
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)