[s2s] Fix t5 warning for distributed eval (#7487)
This commit is contained in:
parent
4c6728460a
commit
6fe8a693eb
|
@ -42,8 +42,7 @@ def eval_data_dir(
|
|||
task="summarization",
|
||||
local_rank=None,
|
||||
num_return_sequences=1,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
dataset_kwargs: Dict = None,
|
||||
prefix="",
|
||||
**generate_kwargs,
|
||||
) -> Dict:
|
||||
|
@ -78,9 +77,8 @@ def eval_data_dir(
|
|||
max_target_length=1024,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
prefix=prefix,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
# I set shuffle=True for a more accurate progress bar.
|
||||
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
|
||||
|
@ -158,6 +156,11 @@ def run_generate():
|
|||
if intermediate_files:
|
||||
raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
|
||||
# In theory, a node could finish and save before another node hits this. If this happens, we can address later.
|
||||
dataset_kwargs = {}
|
||||
if args.src_lang is not None:
|
||||
dataset_kwargs["src_lang"] = args.src_lang
|
||||
if args.tgt_lang is not None:
|
||||
dataset_kwargs["tgt_lang"] = args.tgt_lang
|
||||
|
||||
Path(args.save_dir).mkdir(exist_ok=True)
|
||||
results, num_replicas = eval_data_dir(
|
||||
|
@ -173,9 +176,7 @@ def run_generate():
|
|||
max_source_length=args.max_source_length,
|
||||
num_return_sequences=args.num_return_sequences,
|
||||
prefix=args.prefix,
|
||||
src_lang=args.src_lang,
|
||||
tgt_lang=args.tgt_lang,
|
||||
**generate_kwargs,
|
||||
dataset_kwargs=dataset_kwargs,
**generate_kwargs,
|
||||
)
|
||||
|
||||
if args.local_rank <= 0:
|
||||
|
|
Loading…
Reference in New Issue