[s2s] Fix t5 warning for distributed eval (#7487)

Sam Shleifer 2020-09-30 16:58:03 -04:00 committed by GitHub
parent 4c6728460a
commit 6fe8a693eb
1 changed file with 8 additions and 7 deletions

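The change itself is small: instead of always forwarding src_lang and tgt_lang to the dataset (they default to None for T5-style runs, which the commit title suggests is what triggered the warning), run_distributed_eval.py now collects only the language arguments that were actually supplied into a dataset_kwargs dict and unpacks it with **. A minimal sketch of that pattern follows; build_dataset is a hypothetical stand-in for Seq2SeqDataset, not code from the repo.

    def build_dataset(data_dir, **dataset_kwargs):
        # Hypothetical stand-in: complain about language kwargs it cannot use,
        # roughly mimicking the warning this commit works around.
        for key in ("src_lang", "tgt_lang"):
            if key in dataset_kwargs and dataset_kwargs[key] is None:
                print(f"warning: {key}=None passed but unused")
        return {"data_dir": data_dir, **dataset_kwargs}

    # Before: language keys were always forwarded, even as None.
    build_dataset("wmt_en_ro", src_lang=None, tgt_lang=None)

    # After: build the kwargs dict only from values the user actually set,
    # then unpack it, so a T5 run never sees the language keys at all.
    cli_src_lang, cli_tgt_lang = None, None  # e.g. parsed CLI args for a T5 run
    dataset_kwargs = {}
    if cli_src_lang is not None:
        dataset_kwargs["src_lang"] = cli_src_lang
    if cli_tgt_lang is not None:
        dataset_kwargs["tgt_lang"] = cli_tgt_lang
    build_dataset("wmt_en_ro", **dataset_kwargs)  # no language keys forwarded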

@@ -42,8 +42,7 @@ def eval_data_dir(
     task="summarization",
     local_rank=None,
     num_return_sequences=1,
-    src_lang=None,
-    tgt_lang=None,
+    dataset_kwargs: Dict = None,
     prefix="",
     **generate_kwargs,
 ) -> Dict:
@@ -78,9 +77,8 @@ def eval_data_dir(
         max_target_length=1024,
         type_path=type_path,
         n_obs=n_obs,
-        src_lang=src_lang,
-        tgt_lang=tgt_lang,
         prefix=prefix,
+        **dataset_kwargs,
     )
     # I set shuffle=True for a more accurate progress bar.
     # If all the longest samples are first, the prog bar estimate is too high at the beginning.
@@ -158,6 +156,11 @@ def run_generate():
     if intermediate_files:
         raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
     # In theory, a node could finish and save before another node hits this. If this happens, we can address later.
+    dataset_kwargs = {}
+    if args.src_lang is not None:
+        dataset_kwargs["src_lang"] = args.src_lang
+    if args.tgt_lang is not None:
+        dataset_kwargs["tgt_lang"] = args.tgt_lang
     Path(args.save_dir).mkdir(exist_ok=True)
     results, num_replicas = eval_data_dir(
@@ -173,9 +176,7 @@ def run_generate():
         max_source_length=args.max_source_length,
         num_return_sequences=args.num_return_sequences,
         prefix=args.prefix,
-        src_lang=args.src_lang,
-        tgt_lang=args.tgt_lang,
-        **generate_kwargs,
+        dataset_kwargs=dataset_kwargs, **generate_kwargs,
     )
     if args.local_rank <= 0: