[s2s] run_eval.py parses generate_kwargs (#6948)

This commit is contained in:
Sam Shleifer 2020-09-04 14:19:31 -04:00 committed by GitHub
parent 6078b12098
commit a4fc0c80b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 24 deletions

View File

@ -15,9 +15,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
logger = getLogger(__name__)
try:
from .utils import calculate_bleu, calculate_rouge, use_task_specific_params
from .utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
except ImportError:
from utils import calculate_bleu, calculate_rouge, use_task_specific_params
from utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@ -36,7 +36,6 @@ def generate_summaries_or_translations(
device: str = DEFAULT_DEVICE,
fp16=False,
task="summarization",
decoder_start_token_id=None,
**generate_kwargs,
) -> Dict:
"""Save model.generate results to <out_file>, and return how long it took."""
@ -59,7 +58,6 @@ def generate_summaries_or_translations(
summaries = model.generate(
input_ids=batch.input_ids,
attention_mask=batch.attention_mask,
decoder_start_token_id=decoder_start_token_id,
**generate_kwargs,
)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
@ -77,30 +75,20 @@ def run_generate():
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument("save_path", type=str, help="where to save summaries")
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument(
"--score_path",
type=str,
required=False,
default="metrics.json",
help="where to save the rouge score in json format",
)
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument(
"--decoder_start_token_id",
type=int,
default=None,
required=False,
help="Defaults to using config",
)
parser.add_argument(
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
)
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()
parsed = parse_numeric_cl_kwargs(rest)
if parsed:
print(f"parsed the following generate kwargs: {parsed}")
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
if args.n_obs > 0:
examples = examples[: args.n_obs]
@ -115,7 +103,7 @@ def run_generate():
device=args.device,
fp16=args.fp16,
task=args.task,
decoder_start_token_id=args.decoder_start_token_id,
**parsed,
)
if args.reference_path is None:
return

View File

@ -300,6 +300,10 @@ def test_run_eval(model):
score_path,
"--task",
task,
"--num_beams",
"2",
"--length_penalty",
"2.0",
]
with patch.object(sys, "argv", testargs):
run_generate()

View File

@ -5,7 +5,7 @@ import os
import pickle
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List
from typing import Callable, Dict, Iterable, List, Union
import git
import numpy as np
@ -309,3 +309,23 @@ def assert_not_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
npars = len(model_grads)
assert any(model_grads), f"none of {npars} weights require grad"
# CLI Parsing utils
def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
    """Parse leftover command line args (from ``parser.parse_known_args``) into a dict.

    Supports both the two-token form ``--num_beams 2`` and the single-token form
    ``--num_beams=2`` (argparse leaves unknown ``--key=value`` args as one token,
    so the original pairwise-only parsing would fail on them with an odd-count error).

    Assumes all values are numeric: each value is parsed as ``int`` first, falling
    back to ``float``; a non-numeric value raises an informative ``ValueError``.

    Args:
        unparsed_args: leftover argv tokens, e.g. ``["--num_beams", "2", "--length_penalty=2.0"]``.

    Returns:
        Mapping from option name (without the leading ``--``) to its numeric value.
    """
    result: Dict[str, Union[int, float]] = {}
    i = 0
    while i < len(unparsed_args):
        arg = unparsed_args[i]
        assert arg.startswith("--"), f"expected an option starting with --, got {arg}"
        if "=" in arg:
            # single-token form: --key=value
            key, raw_value = arg[2:].split("=", 1)
            i += 1
        else:
            # two-token form: --key value
            assert i + 1 < len(unparsed_args), f"got odd number of unparsed args: {unparsed_args}"
            key = arg[2:]
            raw_value = unparsed_args[i + 1]
            i += 2
        try:
            value: Union[int, float] = int(raw_value)
        except ValueError:
            value = float(raw_value)  # this can raise another informative ValueError
        result[key] = value
    return result