From 6fe8a693ebbfa6e70b880f7c24e0cf524be6fb25 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 30 Sep 2020 16:58:03 -0400 Subject: [PATCH] [s2s] Fix t5 warning for distributed eval (#7487) --- examples/seq2seq/run_distributed_eval.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/seq2seq/run_distributed_eval.py b/examples/seq2seq/run_distributed_eval.py index 4379836cb5..e1132d56d9 100755 --- a/examples/seq2seq/run_distributed_eval.py +++ b/examples/seq2seq/run_distributed_eval.py @@ -42,8 +42,7 @@ def eval_data_dir( task="summarization", local_rank=None, num_return_sequences=1, - src_lang=None, - tgt_lang=None, + dataset_kwargs: Dict = None, prefix="", **generate_kwargs, ) -> Dict: @@ -78,9 +77,8 @@ def eval_data_dir( max_target_length=1024, type_path=type_path, n_obs=n_obs, - src_lang=src_lang, - tgt_lang=tgt_lang, prefix=prefix, + **dataset_kwargs, ) # I set shuffle=True for a more accurate progress bar. # If all the longest samples are first, the prog bar estimate is too high at the beginning. @@ -158,6 +156,11 @@ def run_generate(): if intermediate_files: raise ValueError(f"Found files at {json_save_dir} please move or remove them.") # In theory, a node could finish and save before another node hits this. If this happens, we can address later. + dataset_kwargs = {} + if args.src_lang is not None: + dataset_kwargs["src_lang"] = args.src_lang + if args.tgt_lang is not None: + dataset_kwargs["tgt_lang"] = args.tgt_lang Path(args.save_dir).mkdir(exist_ok=True) results, num_replicas = eval_data_dir( @@ -173,9 +176,7 @@ def run_generate(): max_source_length=args.max_source_length, num_return_sequences=args.num_return_sequences, prefix=args.prefix, - src_lang=args.src_lang, - tgt_lang=args.tgt_lang, - **generate_kwargs, + dataset_kwargs=dataset_kwargs ** generate_kwargs, ) if args.local_rank <= 0: