exclude jit time from the speed metric calculation of evaluation and prediction (#20553)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

Wang, Yi 2022-12-06 20:37:01 +08:00 committed by GitHub
parent 25e10da427
commit ae06bce888
6 changed files with 42 additions and 6 deletions
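
What the change does: with jit_mode_eval enabled, the first evaluation batch pays a one-off torch.jit trace/compile cost. The trainer now records that cost as <prefix>_jit_compilation_time and shifts the measured start time forward by it before calling speed_metrics, so <prefix>_runtime and <prefix>_samples_per_second reflect inference only. Below is a minimal, self-contained sketch of the idea; the speed_metrics function here is a simplified stand-in for transformers.trainer_utils.speed_metrics, not the real implementation.

import time

def speed_metrics(prefix, start_time, num_samples):
    # Simplified stand-in: report wall-clock runtime and throughput from a start timestamp.
    runtime = time.time() - start_time
    return {
        f"{prefix}_runtime": round(runtime, 4),
        f"{prefix}_samples_per_second": round(num_samples / runtime, 3),
    }

# Pretend evaluation took 5s of wall clock, of which 2s was the one-off JIT trace.
metrics = {"eval_jit_compilation_time": 2.0}  # recorded by the evaluation loop
start_time = time.time() - 5.0

# The fix: shift start_time forward by the compile time so only the ~3s of
# real inference counts toward the throughput numbers.
if "eval_jit_compilation_time" in metrics:
    start_time += metrics["eval_jit_compilation_time"]

metrics.update(speed_metrics("eval", start_time, num_samples=100))
print(metrics)  # eval_runtime ~= 3.0, eval_samples_per_second ~= 33.3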


@@ -51,10 +51,13 @@ class QuestionAnsweringTrainer(Trainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -74,7 +77,7 @@ class QuestionAnsweringTrainer(Trainer):
                     metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
             metrics.update(output.metrics)
         else:
-            metrics = {}
+            metrics = output.metrics
 
         if self.args.should_log:
             # Only the main node log the results by default
@@ -103,10 +106,13 @@ class QuestionAnsweringTrainer(Trainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,


@@ -71,10 +71,13 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -94,9 +97,9 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
                 if not key.startswith(f"{metric_key_prefix}_"):
                     metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
-            output.metrics.update(metrics)
+            metrics.update(output.metrics)
         else:
-            metrics = {}
+            metrics = output.metrics
 
         if self.args.should_log:
             # Only the main node log the results by default
@@ -106,7 +109,7 @@
             # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
             xm.master_print(met.metrics_report())
 
-        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
         return metrics
 
     def predict(
@@ -119,6 +122,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
         # Temporarily disable metric computation, we will do it in the loop here.
         compute_metrics = self.compute_metrics
         self.compute_metrics = None
+        start_time = time.time()
         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
         try:
             output = eval_loop(
@@ -128,10 +132,22 @@
                 # self.args.prediction_loss_only
                 prediction_loss_only=True if compute_metrics is None else None,
                 ignore_keys=ignore_keys,
+                metric_key_prefix=metric_key_prefix,
             )
         finally:
             self.compute_metrics = compute_metrics
+        total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        output.metrics.update(
+            speed_metrics(
+                metric_key_prefix,
+                start_time,
+                num_samples=output.num_samples,
+                num_steps=math.ceil(output.num_samples / total_batch_size),
+            )
+        )
 
         if self.post_process_function is None or self.compute_metrics is None:
             return output
@@ -142,5 +158,5 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
         for key in list(metrics.keys()):
             if not key.startswith(f"{metric_key_prefix}_"):
                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
-
+        metrics.update(output.metrics)
         return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)


@@ -766,6 +766,7 @@ def parse_log_history(log_history):
             _ = metrics.pop("eval_runtime", None)
             _ = metrics.pop("eval_samples_per_second", None)
             _ = metrics.pop("eval_steps_per_second", None)
+            _ = metrics.pop("eval_jit_compilation_time", None)
             values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
             for k, v in metrics.items():
                 if k == "eval_loss":


@@ -1345,7 +1345,9 @@ class Trainer:
             model = nn.DataParallel(model)
 
         if self.args.jit_mode_eval:
+            start_time = time.time()
             model = self.torch_jit_model_eval(model, dataloader, training)
+            self.jit_compilation_time = round(time.time() - start_time, 4)
 
         # Note: in torch.distributed mode, there's no point in wrapping the model
         # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
@@ -2819,6 +2821,8 @@
         )
 
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -2886,6 +2890,8 @@
             test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
         )
         total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
         output.metrics.update(
             speed_metrics(
                 metric_key_prefix,
@@ -3102,6 +3108,8 @@
         if all_losses is not None:
             metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+        if hasattr(self, "jit_compilation_time"):
+            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
 
         # Prefix all keys with metric_key_prefix + '_'
         for key in list(metrics.keys()):
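
For orientation, the trainer.py hunks above split the bookkeeping across three places: _wrap_model times the one-off JIT trace, evaluation_loop surfaces that time as a metric, and evaluate/predict shift their start time by it. The following is a compressed, runnable sketch of that flow under simplified names; the class and methods below are illustrative placeholders, not the real Trainer API.

import time

def fake_jit_trace(model):
    # Stand-in for the JIT tracing step: pretend tracing takes a while.
    time.sleep(0.05)
    return model

class JitTimingSketch:
    def wrap_model(self, model, jit_mode_eval=True):
        if jit_mode_eval:
            start_time = time.time()
            model = fake_jit_trace(model)
            # Remembered on the instance so the evaluation loop can report it later.
            self.jit_compilation_time = round(time.time() - start_time, 4)
        return model

    def evaluation_loop_metrics(self, metrics, metric_key_prefix="eval"):
        # Surface the recorded compile time so evaluate()/predict() can shift
        # their start_time before computing speed metrics.
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
        return metrics

sketch = JitTimingSketch()
sketch.wrap_model(object())
print(sketch.evaluation_loop_metrics({}))  # {'eval_jit_compilation_time': ~0.05}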


@@ -224,7 +224,11 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
     loss = metrics.pop("eval_loss", None)
     _ = metrics.pop("epoch", None)
     # Remove speed metrics
-    speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_per_second")]
+    speed_metrics = [
+        m
+        for m in metrics.keys()
+        if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time")
+    ]
     for sm in speed_metrics:
         _ = metrics.pop(sm, None)
     return loss if len(metrics) == 0 else sum(metrics.values())
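
A usage note (not part of the commit): the extra endswith("_compilation_time") check in default_compute_objective keeps the new metric out of hyperparameter-search objectives, just like runtime and throughput are already excluded. A quick runnable illustration of the filtering:

metrics = {
    "eval_loss": 0.42,
    "eval_runtime": 12.3,
    "eval_samples_per_second": 81.0,
    "eval_jit_compilation_time": 2.0,
    "epoch": 3.0,
}

loss = metrics.pop("eval_loss", None)
_ = metrics.pop("epoch", None)
# Same filter as the updated default_compute_objective: drop speed/compile metrics.
speed_keys = [
    m
    for m in metrics.keys()
    if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time")
]
for sm in speed_keys:
    _ = metrics.pop(sm, None)

objective = loss if len(metrics) == 0 else sum(metrics.values())
print(objective)  # 0.42 -- only the task metric survives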


@@ -339,6 +339,7 @@ class NotebookProgressCallback(TrainerCallback):
             _ = metrics.pop(f"{metric_key_prefix}_runtime", None)
             _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
             _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
+            _ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
             for k, v in metrics.items():
                 if k == f"{metric_key_prefix}_loss":
                     values["Validation Loss"] = v