#!/usr/bin/env python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import itertools
import operator
import sys
from collections import OrderedDict

from run_eval import datetime_now, run_generate

from utils import ROUGE_KEYS


# A table of supported tasks and the list of scores in the order of importance to be sorted by.
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
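# (e.g. a hypothetical new entry could be: "question-answering": ["em", "f1"], assuming run_generate()
# returned "em" and "f1" scores for that task)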
task_score_names = {
    "translation": ["bleu"],
    "summarization": ROUGE_KEYS,
}


def parse_search_arg(search):
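    """
    Expand the value of ``--search`` into the cartesian product of the listed hparam values.

    Illustrative example (hypothetical values): ``parse_search_arg("num_beams=5:10 length_penalty=0.8:1.0")``
    returns::

        matrix = [
            ["--num_beams 5", "--length_penalty 0.8"],
            ["--num_beams 5", "--length_penalty 1.0"],
            ["--num_beams 10", "--length_penalty 0.8"],
            ["--num_beams 10", "--length_penalty 1.0"],
        ]
        entry_names = ["num_beams", "length_penalty"]
    """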
    groups = search.split()
    entries = dict((g.split("=") for g in groups))
    entry_names = list(entries.keys())
    sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
    matrix = [list(x) for x in itertools.product(*sets)]
    return matrix, entry_names


def run_search():
"""
|
|
Run parametric search over the desired hparam space with help of ``run_eval.py``.
|
|
|
|
All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of "--search" are parsed, reformatted and fed to ``run_eval.py`` as additional args.
|
|
|
|
The format for the ``--search`` value is a simple string with hparams and colon separated values to try, e.g.:
|
|
```
|
|
--search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
|
|
```
|
|
which will generate ``12`` ``(2*3*2)`` searches for a product of each hparam. For example the example that was just used will invoke ``run_eval.py`` repeatedly with:
|
|
|
|
```
|
|
--num_beams 5 --length_penalty 0.8 --early_stopping true
|
|
--num_beams 5 --length_penalty 0.8 --early_stopping false
|
|
[...]
|
|
--num_beams 10 --length_penalty 1.2 --early_stopping false
|
|
```
|
|
|
|
On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments.
|
|
|
|
|
|
"""
    prog = sys.argv[0]

    parser = argparse.ArgumentParser(
        usage=(
            "\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore"
            " refer to `run_eval.py -h` for the complete list."
        )
    )
    parser.add_argument(
        "--search",
        type=str,
        required=False,
        help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"',
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
    )
    parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
    parser.add_argument(
        "--info",
        nargs="?",
        type=str,
        const=datetime_now(),
        help=(
            "add custom notes to be printed before the results table. If no value is passed, the current datetime"
            " string will be used."
        ),
    )
    args, args_main = parser.parse_known_args()
    # we share some of the args
    args_main.extend(["--task", args.task])
    args_normal = [prog] + args_main

    # to support variations like translation_en_to_de
    task = "translation" if "translation" in args.task else "summarization"

    matrix, col_names = parse_search_arg(args.search)
    col_names[0:0] = task_score_names[task]  # score cols first
    col_widths = {col: len(str(col)) for col in col_names}
    results = []
    for r in matrix:
        hparams = dict((x.replace("--", "").split() for x in r))
        args_exp = " ".join(r).split()
        args_exp.extend(["--bs", str(args.bs)])  # in case we need to reduce its size due to CUDA OOM
        sys.argv = args_normal + args_exp

        # XXX: need to trap CUDA OOM and lower args.bs if that happens and retry

        scores = run_generate(verbose=False)
        # make sure scores are first in the table
        result = OrderedDict()
        for score in task_score_names[task]:
            result[score] = scores[score]
        result.update(hparams)
        results.append(result)

        # find widest entries
        for k, v in result.items():
            l = len(str(v))
            if l > col_widths[k]:
                col_widths[k] = l

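    # render a markdown table of all results, column-aligned, sorted by the task's scores, best first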
    results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
    print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
    print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
    for row in results_sorted:
        print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))

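    # report the args of the best-scoring run, so it can be re-run directly with run_eval.py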
    best = results_sorted[0]
    for score in task_score_names[task]:
        del best[score]
    best_args = [f"--{k} {v}" for k, v in best.items()]
    dyn_args = ["--bs", str(args.bs)]
    if args.info:
        print(f"\nInfo: {args.info}")
    print("\nBest score args:")
    print(" ".join(args_main + best_args + dyn_args))

    return results_sorted


if __name__ == "__main__":
    # Usage:
    # [normal-run_eval_search.py cmd plus] \
    # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
    #
    # Example:
    # PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval_search.py $MODEL_NAME \
    # $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target \
    # --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation \
    # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
    run_search()