fix examples/rag imports, tests (#7712)

This commit is contained in:
Sam Shleifer 2020-10-14 11:35:00 -04:00 committed by GitHub
parent 890e790e16
commit 8feb0cc967
5 changed files with 42 additions and 18 deletions


@ -65,26 +65,41 @@ Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greates
We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45).
1. Download and unzip the gold data file. We use the `biencoder-nq-dev` from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz.
```bash
wget https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz && gzip -d biencoder-nq-dev.json.gz
```
2. Parse the unzipped file using `parse_dpr_relevance_data.py`:
```bash
mkdir output # or wherever you want to save this
python examples/rag/parse_dpr_relevance_data.py \
--src_path biencoder-nq-dev.json \
--evaluation_set output/biencoder-nq-dev.questions \
--gold_data_path output/biencoder-nq-dev.pages
```
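To make the parsing step concrete, here is a minimal sketch of what it produces. This is an illustration, not the actual `parse_dpr_relevance_data.py`: it assumes the DPR gold file is a JSON list of records, each with a `question` and a list of `positive_ctxs` carrying page `title` fields, as in the DPR release.

```python
def parse_dpr_gold(records):
    """Split DPR gold records into evaluation questions and gold page titles.

    Hypothetical sketch: one question per line in the .questions file, and the
    tab-separated titles of the relevant (positive) pages in the .pages file.
    """
    questions, pages = [], []
    for record in records:
        questions.append(record["question"])
        pages.append("\t".join(ctx["title"] for ctx in record["positive_ctxs"]))
    return questions, pages

# Tiny inline example standing in for the real biencoder-nq-dev.json
sample = [
    {
        "question": "who is the owner of reading football club",
        "positive_ctxs": [{"title": "Reading F.C."}],
    }
]
questions, pages = parse_dpr_gold(sample)
```

Each output file then has one line per evaluation sample, which is the format the evaluation script below consumes.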
3. Run evaluation:
```bash
python examples/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \
--model_type rag_sequence \
--evaluation_set output/biencoder-nq-dev.questions \
--gold_data_path output/biencoder-nq-dev.pages \
--predictions_path output/retrieval_preds.tsv \
--eval_mode retrieval \
--k 1
```
```bash
# EXPLANATION
python examples/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
--evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
--gold_data_path output/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set
--predictions_path output/retrieval_preds.tsv \ # name of the file where predictions will be stored
--eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation
--k 1 # parameter k for the precision@k metric
```
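The precision@k metric mentioned above is simple to state. Here is a minimal sketch (an illustration, not the code inside `eval_rag.py`), assuming predictions and gold pages are lists of page titles as in the tab-separated files produced earlier:

```python
def precision_at_k(predicted_titles, gold_titles, k=1):
    """Fraction of the top-k retrieved page titles that appear in the gold set."""
    top_k = predicted_titles[:k]
    if not top_k:
        return 0.0
    gold = set(gold_titles)
    return sum(1 for title in top_k if title in gold) / len(top_k)

# One prediction row vs. one gold row, both tab-separated as in the files above
pred = "Reading F.C.\tReading, Berkshire".split("\t")
gold = "Reading F.C.".split("\t")
score = precision_at_k(pred, gold, k=1)
```

With `--k 1`, only the top retrieved document counts, so the score per sample is either 0 or 1 and the reported metric is the average over the evaluation set.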
## End-to-end evaluation
We support two formats of the gold data file (controlled by the `gold_data_mode` parameter):
@ -97,7 +112,9 @@ who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiul
Xiu Li Dai
```
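Based on the two layouts shown above, the gold file could be read as follows. This is an illustrative sketch, not the script's actual loader: in `qa` mode each line is assumed to be a tab-separated question and a Python-style list of acceptable answers; in `ans` mode each line is a single answer.

```python
import ast

def read_gold_line(line, gold_data_mode="qa"):
    # "qa":  "<question>\t['answer 1', 'answer 2']"  -> list of acceptable answers
    # "ans": "<answer>"                              -> single-answer list
    if gold_data_mode == "qa":
        _question, answers = line.strip().split("\t")
        return ast.literal_eval(answers)
    return [line.strip()]

qa_line = "who is the owner of reading football club\t['Xiu Li Dai', 'Dai Yongge', 'Dai Xiuli', 'Renhe Group']"
answers = read_gold_line(qa_line, "qa")
single = read_gold_line("Xiu Li Dai\n", "ans")
```

`ast.literal_eval` safely parses the bracketed answer list without executing arbitrary code, which is why it is preferable to `eval` for this kind of input.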
Predictions of the model for the samples from the `evaluation_set` will be saved under the path specified by the `predictions_path` parameter.
If this path already exists, the script will use saved predictions to calculate metrics.
Add the `--recalculate` parameter to force the script to perform inference from scratch.
An example e2e evaluation run could look as follows:
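A plausible sketch of such a run, assuming the same flags as the retrieval example above plus the `gold_data_mode` and `--recalculate` parameters described in this section. The file names here are placeholders, not the original example:

```shell
# Reconstructed sketch; evaluation_set and gold_data_path are placeholder names.
python examples/rag/eval_rag.py \
    --model_name_or_path facebook/rag-sequence-nq \
    --model_type rag_sequence \
    --evaluation_set output/nq-test.questions \
    --gold_data_path output/nq-test.gold \
    --predictions_path output/e2e_preds.txt \
    --eval_mode e2e \
    --gold_data_mode qa \
    --recalculate
```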


@ -0,0 +1,5 @@
import os
import sys

# Put this directory first on sys.path so tests can import the sibling
# modules in examples/rag (utils, callbacks, distributed_retriever) directly.
sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))


@ -15,7 +15,7 @@ from transformers import logging as transformers_logging
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
from utils import exact_match_score, f1_score  # noqa: E402 # isort:skip
logger = logging.getLogger(__name__)


@ -31,16 +31,13 @@ from transformers import (
from transformers import logging as transformers_logging
sys.path.append(os.path.join(os.getcwd()))  # noqa: E402 # isort:skip
from callbacks import (  # noqa: E402 # isort:skip
get_checkpoint_callback,
get_early_stopping_callback,
Seq2SeqLoggingCallback,
)
from distributed_retriever import RagPyTorchDistributedRetriever  # noqa: E402 # isort:skip
from utils import (  # noqa: E402 # isort:skip
calculate_exact_match,
flatten_list,
get_git_info,
@ -53,6 +50,11 @@ from examples.rag.utils import ( # noqa: E402 # isort:skip
Seq2SeqDataset,
)
# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@ -23,7 +23,7 @@ from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FI
sys.path.append(os.path.join(os.getcwd()))  # noqa: E402 # isort:skip
from distributed_retriever import RagPyTorchDistributedRetriever  # noqa: E402 # isort:skip
def require_distributed_retrieval(test_case):