exclude jit time from the speed metric calculation of evaluation and prediction (#20553)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
25e10da427
commit
ae06bce888
|
@ -51,10 +51,13 @@ class QuestionAnsweringTrainer(Trainer):
|
|||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if compute_metrics is None else None,
|
||||
ignore_keys=ignore_keys,
|
||||
metric_key_prefix=metric_key_prefix,
|
||||
)
|
||||
finally:
|
||||
self.compute_metrics = compute_metrics
|
||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||
output.metrics.update(
|
||||
speed_metrics(
|
||||
metric_key_prefix,
|
||||
|
@ -74,7 +77,7 @@ class QuestionAnsweringTrainer(Trainer):
|
|||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
metrics.update(output.metrics)
|
||||
else:
|
||||
metrics = {}
|
||||
metrics = output.metrics
|
||||
|
||||
if self.args.should_log:
|
||||
# Only the main node log the results by default
|
||||
|
@ -103,10 +106,13 @@ class QuestionAnsweringTrainer(Trainer):
|
|||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if compute_metrics is None else None,
|
||||
ignore_keys=ignore_keys,
|
||||
metric_key_prefix=metric_key_prefix,
|
||||
)
|
||||
finally:
|
||||
self.compute_metrics = compute_metrics
|
||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||
output.metrics.update(
|
||||
speed_metrics(
|
||||
metric_key_prefix,
|
||||
|
|
|
@ -71,10 +71,13 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if compute_metrics is None else None,
|
||||
ignore_keys=ignore_keys,
|
||||
metric_key_prefix=metric_key_prefix,
|
||||
)
|
||||
finally:
|
||||
self.compute_metrics = compute_metrics
|
||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||
output.metrics.update(
|
||||
speed_metrics(
|
||||
metric_key_prefix,
|
||||
|
@ -94,9 +97,9 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||
if not key.startswith(f"{metric_key_prefix}_"):
|
||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
|
||||
output.metrics.update(metrics)
|
||||
metrics.update(output.metrics)
|
||||
else:
|
||||
metrics = {}
|
||||
metrics = output.metrics
|
||||
|
||||
if self.args.should_log:
|
||||
# Only the main node log the results by default
|
||||
|
@ -106,7 +109,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
||||
xm.master_print(met.metrics_report())
|
||||
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
return metrics
|
||||
|
||||
def predict(
|
||||
|
@ -119,6 +122,7 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||
# Temporarily disable metric computation, we will do it in the loop here.
|
||||
compute_metrics = self.compute_metrics
|
||||
self.compute_metrics = None
|
||||
start_time = time.time()
|
||||
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
||||
try:
|
||||
output = eval_loop(
|
||||
|
@ -128,10 +132,22 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if compute_metrics is None else None,
|
||||
ignore_keys=ignore_keys,
|
||||
metric_key_prefix=metric_key_prefix,
|
||||
)
|
||||
finally:
|
||||
self.compute_metrics = compute_metrics
|
||||
|
||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||
output.metrics.update(
|
||||
speed_metrics(
|
||||
metric_key_prefix,
|
||||
start_time,
|
||||
num_samples=output.num_samples,
|
||||
num_steps=math.ceil(output.num_samples / total_batch_size),
|
||||
)
|
||||
)
|
||||
if self.post_process_function is None or self.compute_metrics is None:
|
||||
return output
|
||||
|
||||
|
@ -142,5 +158,5 @@ class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
|
|||
for key in list(metrics.keys()):
|
||||
if not key.startswith(f"{metric_key_prefix}_"):
|
||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
|
||||
metrics.update(output.metrics)
|
||||
return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
|
||||
|
|
|
@ -766,6 +766,7 @@ def parse_log_history(log_history):
|
|||
_ = metrics.pop("eval_runtime", None)
|
||||
_ = metrics.pop("eval_samples_per_second", None)
|
||||
_ = metrics.pop("eval_steps_per_second", None)
|
||||
_ = metrics.pop("eval_jit_compilation_time", None)
|
||||
values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
|
||||
for k, v in metrics.items():
|
||||
if k == "eval_loss":
|
||||
|
|
|
@ -1345,7 +1345,9 @@ class Trainer:
|
|||
model = nn.DataParallel(model)
|
||||
|
||||
if self.args.jit_mode_eval:
|
||||
start_time = time.time()
|
||||
model = self.torch_jit_model_eval(model, dataloader, training)
|
||||
self.jit_compilation_time = round(time.time() - start_time, 4)
|
||||
|
||||
# Note: in torch.distributed mode, there's no point in wrapping the model
|
||||
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
||||
|
@ -2819,6 +2821,8 @@ class Trainer:
|
|||
)
|
||||
|
||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||
output.metrics.update(
|
||||
speed_metrics(
|
||||
metric_key_prefix,
|
||||
|
@ -2886,6 +2890,8 @@ class Trainer:
|
|||
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
||||
)
|
||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
|
||||
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
|
||||
output.metrics.update(
|
||||
speed_metrics(
|
||||
metric_key_prefix,
|
||||
|
@ -3102,6 +3108,8 @@ class Trainer:
|
|||
|
||||
if all_losses is not None:
|
||||
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
|
||||
if hasattr(self, "jit_compilation_time"):
|
||||
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
|
||||
|
||||
# Prefix all keys with metric_key_prefix + '_'
|
||||
for key in list(metrics.keys()):
|
||||
|
|
|
@ -224,7 +224,11 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
|
|||
loss = metrics.pop("eval_loss", None)
|
||||
_ = metrics.pop("epoch", None)
|
||||
# Remove speed metrics
|
||||
speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_per_second")]
|
||||
speed_metrics = [
|
||||
m
|
||||
for m in metrics.keys()
|
||||
if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time")
|
||||
]
|
||||
for sm in speed_metrics:
|
||||
_ = metrics.pop(sm, None)
|
||||
return loss if len(metrics) == 0 else sum(metrics.values())
|
||||
|
|
|
@ -339,6 +339,7 @@ class NotebookProgressCallback(TrainerCallback):
|
|||
_ = metrics.pop(f"{metric_key_prefix}_runtime", None)
|
||||
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
|
||||
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
|
||||
_ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
|
||||
for k, v in metrics.items():
|
||||
if k == f"{metric_key_prefix}_loss":
|
||||
values["Validation Loss"] = v
|
||||
|
|
Loading…
Reference in New Issue