[s2s] run_eval.py parses generate_kwargs (#6948)
This commit is contained in:
parent
6078b12098
commit
a4fc0c80b1
|
@ -15,9 +15,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .utils import calculate_bleu, calculate_rouge, use_task_specific_params
|
from .utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from utils import calculate_bleu, calculate_rouge, use_task_specific_params
|
from utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
|
||||||
|
|
||||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
@ -36,7 +36,6 @@ def generate_summaries_or_translations(
|
||||||
device: str = DEFAULT_DEVICE,
|
device: str = DEFAULT_DEVICE,
|
||||||
fp16=False,
|
fp16=False,
|
||||||
task="summarization",
|
task="summarization",
|
||||||
decoder_start_token_id=None,
|
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Save model.generate results to <out_file>, and return how long it took."""
|
"""Save model.generate results to <out_file>, and return how long it took."""
|
||||||
|
@ -59,7 +58,6 @@ def generate_summaries_or_translations(
|
||||||
summaries = model.generate(
|
summaries = model.generate(
|
||||||
input_ids=batch.input_ids,
|
input_ids=batch.input_ids,
|
||||||
attention_mask=batch.attention_mask,
|
attention_mask=batch.attention_mask,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
|
@ -77,30 +75,20 @@ def run_generate():
|
||||||
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
|
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
|
||||||
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
|
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
|
||||||
parser.add_argument("save_path", type=str, help="where to save summaries")
|
parser.add_argument("save_path", type=str, help="where to save summaries")
|
||||||
|
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
|
||||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
|
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
|
||||||
parser.add_argument(
|
|
||||||
"--score_path",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
default="metrics.json",
|
|
||||||
help="where to save the rouge score in json format",
|
|
||||||
)
|
|
||||||
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
|
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
|
||||||
parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
|
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
|
||||||
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
||||||
parser.add_argument(
|
|
||||||
"--decoder_start_token_id",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
required=False,
|
|
||||||
help="Defaults to using config",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
|
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
|
||||||
)
|
)
|
||||||
parser.add_argument("--fp16", action="store_true")
|
parser.add_argument("--fp16", action="store_true")
|
||||||
args = parser.parse_args()
|
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
|
||||||
|
args, rest = parser.parse_known_args()
|
||||||
|
parsed = parse_numeric_cl_kwargs(rest)
|
||||||
|
if parsed:
|
||||||
|
print(f"parsed the following generate kwargs: {parsed}")
|
||||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
||||||
if args.n_obs > 0:
|
if args.n_obs > 0:
|
||||||
examples = examples[: args.n_obs]
|
examples = examples[: args.n_obs]
|
||||||
|
@ -115,7 +103,7 @@ def run_generate():
|
||||||
device=args.device,
|
device=args.device,
|
||||||
fp16=args.fp16,
|
fp16=args.fp16,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
decoder_start_token_id=args.decoder_start_token_id,
|
**parsed,
|
||||||
)
|
)
|
||||||
if args.reference_path is None:
|
if args.reference_path is None:
|
||||||
return
|
return
|
||||||
|
|
|
@ -300,6 +300,10 @@ def test_run_eval(model):
|
||||||
score_path,
|
score_path,
|
||||||
"--task",
|
"--task",
|
||||||
task,
|
task,
|
||||||
|
"--num_beams",
|
||||||
|
"2",
|
||||||
|
"--length_penalty",
|
||||||
|
"2.0",
|
||||||
]
|
]
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
run_generate()
|
run_generate()
|
||||||
|
|
|
@ -5,7 +5,7 @@ import os
|
||||||
import pickle
|
import pickle
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Iterable, List
|
from typing import Callable, Dict, Iterable, List, Union
|
||||||
|
|
||||||
import git
|
import git
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -309,3 +309,23 @@ def assert_not_all_frozen(model):
|
||||||
model_grads: List[bool] = list(grad_status(model))
|
model_grads: List[bool] = list(grad_status(model))
|
||||||
npars = len(model_grads)
|
npars = len(model_grads)
|
||||||
assert any(model_grads), f"none of {npars} weights require grad"
|
assert any(model_grads), f"none of {npars} weights require grad"
|
||||||
|
|
||||||
|
|
||||||
|
# CLI Parsing utils
|
||||||
|
|
||||||
|
|
||||||
|
def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
|
||||||
|
"""Parse an argv list of unspecified command line args to a dict. Assumes all values are numeric."""
|
||||||
|
result = {}
|
||||||
|
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
|
||||||
|
num_pairs = len(unparsed_args) // 2
|
||||||
|
for pair_num in range(num_pairs):
|
||||||
|
i = 2 * pair_num
|
||||||
|
assert unparsed_args[i].startswith("--")
|
||||||
|
try:
|
||||||
|
value = int(unparsed_args[i + 1])
|
||||||
|
except ValueError:
|
||||||
|
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
|
||||||
|
|
||||||
|
result[unparsed_args[i][2:]] = value
|
||||||
|
return result
|
||||||
|
|
Loading…
Reference in New Issue