[s2s] run_eval/run_eval_search tweaks (#7192)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
parent
9c5bcab5b0
commit
efeab6a3f1
|
@ -15,7 +15,6 @@ except ImportError:
|
|||
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
|
||||
task_score_names = {
|
||||
"translation": ["bleu"],
|
||||
"translation_en_to_de": ["bleu"],
|
||||
"summarization": ["rouge1", "rouge2", "rougeL"],
|
||||
}
|
||||
|
||||
|
@ -66,9 +65,7 @@ def run_search():
|
|||
parser.add_argument(
|
||||
"--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys()
|
||||
)
|
||||
parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
|
||||
parser.add_argument(
|
||||
"--info",
|
||||
nargs="?",
|
||||
|
@ -81,8 +78,11 @@ def run_search():
|
|||
args_main.extend(["--task", args.task])
|
||||
args_normal = [prog] + args_main
|
||||
|
||||
# to support variations like translation_en_to_de"
|
||||
task = "translation" if "translation" in args.task else "summarization"
|
||||
|
||||
matrix, col_names = parse_search_arg(args.search)
|
||||
col_names[0:0] = task_score_names[args.task] # score cols first
|
||||
col_names[0:0] = task_score_names[task] # score cols first
|
||||
col_widths = {col: len(str(col)) for col in col_names}
|
||||
results = []
|
||||
for r in matrix:
|
||||
|
@ -96,7 +96,7 @@ def run_search():
|
|||
scores = run_generate(verbose=False)
|
||||
# make sure scores are first in the table
|
||||
result = OrderedDict()
|
||||
for score in task_score_names[args.task]:
|
||||
for score in task_score_names[task]:
|
||||
result[score] = scores[score]
|
||||
result.update(hparams)
|
||||
results.append(result)
|
||||
|
@ -107,14 +107,14 @@ def run_search():
|
|||
if l > col_widths[k]:
|
||||
col_widths[k] = l
|
||||
|
||||
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True)
|
||||
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
|
||||
print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
|
||||
print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
|
||||
for row in results_sorted:
|
||||
print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))
|
||||
|
||||
best = results_sorted[0]
|
||||
for score in task_score_names[args.task]:
|
||||
for score in task_score_names[task]:
|
||||
del best[score]
|
||||
best_args = [f"--{k} {v}" for k, v in best.items()]
|
||||
dyn_args = ["--bs", str(args.bs)]
|
||||
|
|
|
@ -106,6 +106,9 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
|
|||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||
BERT_BASE_CASED = "bert-base-cased"
|
||||
PEGASUS_XSUM = "google/pegasus-xsum"
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
|
@ -284,8 +287,7 @@ class TestSummarizationDistiller(unittest.TestCase):
|
|||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
||||
def test_run_eval(model):
|
||||
def run_eval_tester(model):
|
||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
assert not output_file_name.exists()
|
||||
|
@ -293,28 +295,39 @@ def test_run_eval(model):
|
|||
_dump_articles(input_file_name, articles)
|
||||
score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
|
||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||
testargs = [
|
||||
"run_eval.py",
|
||||
model,
|
||||
str(input_file_name),
|
||||
str(output_file_name),
|
||||
"--score_path",
|
||||
score_path,
|
||||
"--task",
|
||||
task,
|
||||
"--num_beams",
|
||||
"2",
|
||||
"--length_penalty",
|
||||
"2.0",
|
||||
]
|
||||
testargs = f"""
|
||||
run_eval_search.py
|
||||
{model}
|
||||
{input_file_name}
|
||||
{output_file_name}
|
||||
--score_path {score_path}
|
||||
--task {task}
|
||||
--num_beams 2
|
||||
--length_penalty 2.0
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_generate()
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
|
||||
# test one model to quickly (no-@slow) catch simple problems and do an
|
||||
# extensive testing of functionality with multiple models as @slow separately
|
||||
def test_run_eval():
|
||||
run_eval_tester(T5_TINY)
|
||||
|
||||
|
||||
# any extra models should go into the list here - can be slow
|
||||
@slow
|
||||
@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
|
||||
@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
|
||||
def test_run_eval_slow(model):
|
||||
run_eval_tester(model)
|
||||
|
||||
|
||||
# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
|
||||
@slow
|
||||
@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
|
||||
def test_run_eval_search(model):
|
||||
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
||||
output_file_name = input_file_name.parent / "utest_output.txt"
|
||||
|
@ -335,20 +348,17 @@ def test_run_eval_search(model):
|
|||
_dump_articles(input_file_name, text["en"])
|
||||
_dump_articles(reference_path, text["de"])
|
||||
task = "translation_en_to_de" if model == T5_TINY else "summarization"
|
||||
testargs = [
|
||||
"run_eval_search.py",
|
||||
model,
|
||||
str(input_file_name),
|
||||
str(output_file_name),
|
||||
"--score_path",
|
||||
score_path,
|
||||
"--reference_path",
|
||||
reference_path,
|
||||
"--task",
|
||||
task,
|
||||
"--search",
|
||||
"num_beams=1:2 length_penalty=0.9:1.0",
|
||||
]
|
||||
testargs = f"""
|
||||
run_eval_search.py
|
||||
--model_name {model}
|
||||
--data_dir {str(input_file_name)}
|
||||
--save_dir {str(output_file_name)}
|
||||
--score_path {score_path}
|
||||
--reference_path {reference_path},
|
||||
--task {task}
|
||||
--search num_beams=1:2 length_penalty=0.9:1.0
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
with CaptureStdout() as cs:
|
||||
run_search()
|
||||
|
@ -367,8 +377,8 @@ def test_run_eval_search(model):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model"],
|
||||
[pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
|
||||
"model",
|
||||
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY],
|
||||
)
|
||||
def test_finetune(model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
@ -541,13 +551,13 @@ def test_pack_dataset():
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["tok_name"],
|
||||
"tok_name",
|
||||
[
|
||||
pytest.param(MBART_TINY),
|
||||
pytest.param(MARIAN_TINY),
|
||||
pytest.param(T5_TINY),
|
||||
pytest.param(BART_TINY),
|
||||
pytest.param("google/pegasus-xsum"),
|
||||
MBART_TINY,
|
||||
MARIAN_TINY,
|
||||
T5_TINY,
|
||||
BART_TINY,
|
||||
PEGASUS_XSUM,
|
||||
],
|
||||
)
|
||||
def test_seq2seq_dataset_truncation(tok_name):
|
||||
|
@ -589,7 +599,7 @@ def test_seq2seq_dataset_truncation(tok_name):
|
|||
break # No need to test every batch
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
|
||||
@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
|
||||
def test_legacy_dataset_truncation(tok):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||
tmp_dir = make_test_data_dir()
|
||||
|
|
Loading…
Reference in New Issue