[s2s] run_eval/run_eval_search tweaks (#7192)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Stas Bekman 2020-09-17 11:26:38 -07:00 committed by GitHub
parent 9c5bcab5b0
commit efeab6a3f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 48 deletions

View File

@ -15,7 +15,6 @@ except ImportError:
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
task_score_names = {
"translation": ["bleu"],
"translation_en_to_de": ["bleu"],
"summarization": ["rouge1", "rouge2", "rougeL"],
}
@ -66,9 +65,7 @@ def run_search():
parser.add_argument(
"--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
)
parser.add_argument(
"--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys()
)
parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
parser.add_argument(
"--info",
nargs="?",
@ -81,8 +78,11 @@ def run_search():
args_main.extend(["--task", args.task])
args_normal = [prog] + args_main
# to support variations like translation_en_to_de"
task = "translation" if "translation" in args.task else "summarization"
matrix, col_names = parse_search_arg(args.search)
col_names[0:0] = task_score_names[args.task] # score cols first
col_names[0:0] = task_score_names[task] # score cols first
col_widths = {col: len(str(col)) for col in col_names}
results = []
for r in matrix:
@ -96,7 +96,7 @@ def run_search():
scores = run_generate(verbose=False)
# make sure scores are first in the table
result = OrderedDict()
for score in task_score_names[args.task]:
for score in task_score_names[task]:
result[score] = scores[score]
result.update(hparams)
results.append(result)
@ -107,14 +107,14 @@ def run_search():
if l > col_widths[k]:
col_widths[k] = l
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True)
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
for row in results_sorted:
print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))
best = results_sorted[0]
for score in task_score_names[args.task]:
for score in task_score_names[task]:
del best[score]
best_args = [f"--{k} {v}" for k, v in best.items()]
dyn_args = ["--bs", str(args.bs)]

View File

@ -106,6 +106,9 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
BERT_BASE_CASED = "bert-base-cased"
PEGASUS_XSUM = "google/pegasus-xsum"
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
@ -284,8 +287,7 @@ class TestSummarizationDistiller(unittest.TestCase):
return model
@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
def test_run_eval(model):
def run_eval_tester(model):
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file_name = input_file_name.parent / "utest_output.txt"
assert not output_file_name.exists()
@ -293,28 +295,39 @@ def test_run_eval(model):
_dump_articles(input_file_name, articles)
score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
task = "translation_en_to_de" if model == T5_TINY else "summarization"
testargs = [
"run_eval.py",
model,
str(input_file_name),
str(output_file_name),
"--score_path",
score_path,
"--task",
task,
"--num_beams",
"2",
"--length_penalty",
"2.0",
]
testargs = f"""
run_eval_search.py
{model}
{input_file_name}
{output_file_name}
--score_path {score_path}
--task {task}
--num_beams 2
--length_penalty 2.0
""".split()
with patch.object(sys, "argv", testargs):
run_generate()
assert Path(output_file_name).exists()
os.remove(Path(output_file_name))
# test one model to quickly (no-@slow) catch simple problems and do an
# extensive testing of functionality with multiple models as @slow separately
def test_run_eval():
run_eval_tester(T5_TINY)
# any extra models should go into the list here - can be slow
@slow
@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
def test_run_eval_slow(model):
run_eval_tester(model)
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
@slow
@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
def test_run_eval_search(model):
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file_name = input_file_name.parent / "utest_output.txt"
@ -335,20 +348,17 @@ def test_run_eval_search(model):
_dump_articles(input_file_name, text["en"])
_dump_articles(reference_path, text["de"])
task = "translation_en_to_de" if model == T5_TINY else "summarization"
testargs = [
"run_eval_search.py",
model,
str(input_file_name),
str(output_file_name),
"--score_path",
score_path,
"--reference_path",
reference_path,
"--task",
task,
"--search",
"num_beams=1:2 length_penalty=0.9:1.0",
]
testargs = f"""
run_eval_search.py
--model_name {model}
--data_dir {str(input_file_name)}
--save_dir {str(output_file_name)}
--score_path {score_path}
--reference_path {reference_path},
--task {task}
--search num_beams=1:2 length_penalty=0.9:1.0
""".split()
with patch.object(sys, "argv", testargs):
with CaptureStdout() as cs:
run_search()
@ -367,8 +377,8 @@ def test_run_eval_search(model):
@pytest.mark.parametrize(
["model"],
[pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
"model",
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY],
)
def test_finetune(model):
args_d: dict = CHEAP_ARGS.copy()
@ -541,13 +551,13 @@ def test_pack_dataset():
@pytest.mark.parametrize(
["tok_name"],
"tok_name",
[
pytest.param(MBART_TINY),
pytest.param(MARIAN_TINY),
pytest.param(T5_TINY),
pytest.param(BART_TINY),
pytest.param("google/pegasus-xsum"),
MBART_TINY,
MARIAN_TINY,
T5_TINY,
BART_TINY,
PEGASUS_XSUM,
],
)
def test_seq2seq_dataset_truncation(tok_name):
@ -589,7 +599,7 @@ def test_seq2seq_dataset_truncation(tok_name):
break # No need to test every batch
@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
def test_legacy_dataset_truncation(tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
tmp_dir = make_test_data_dir()