fix examples/rag imports, tests (#7712)
parent 890e790e16
commit 8feb0cc967
@ -65,26 +65,41 @@ Does He Love You Does He Love You Red Sandy Spika dress of Reba McEntire Greates
We demonstrate how to evaluate retrieval against DPR evaluation data. You can download respective files from links listed [here](https://github.com/facebookresearch/DPR/blob/master/data/download_data.py#L39-L45).
1. Download and unzip the gold data file. We use the `biencoder-nq-dev` dataset from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz.
```bash
wget https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz && gzip -d biencoder-nq-dev.json.gz
```

2. Parse the unzipped file using `parse_dpr_relevance_data.py`:
```bash
mkdir output # or wherever you want to save this
python examples/rag/parse_dpr_relevance_data.py \
    --src_path biencoder-nq-dev.json \
    --evaluation_set output/biencoder-nq-dev.questions \
    --gold_data_path output/biencoder-nq-dev.pages
```
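For orientation, the parsing step can be sketched as follows. This is a minimal sketch, assuming the DPR gold file is a JSON list of records with a `question` field and `positive_ctxs` entries carrying a page `title` (the field names are assumptions; check the real script for the exact format):

```python
import json

def parse_dpr_relevance(src_path, evaluation_set, gold_data_path):
    """Split a DPR gold file into a questions file and a gold-pages file."""
    with open(src_path) as src:
        samples = json.load(src)
    with open(evaluation_set, "w") as eval_f, open(gold_data_path, "w") as gold_f:
        for sample in samples:
            # one question per line
            eval_f.write(sample["question"] + "\n")
            # tab-separated titles of the gold (positive) passages
            titles = [ctx["title"] for ctx in sample["positive_ctxs"]]
            gold_f.write("\t".join(titles) + "\n")
```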

3. Run evaluation:
```bash
python examples/rag/eval_rag.py \
    --model_name_or_path facebook/rag-sequence-nq \
    --model_type rag_sequence \
    --evaluation_set output/biencoder-nq-dev.questions \
    --gold_data_path output/biencoder-nq-dev.pages \
    --predictions_path output/retrieval_preds.tsv \
    --eval_mode retrieval \
    --k 1
```

```bash
# EXPLANATION
python examples/rag/eval_rag.py \
    --model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
    --model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
    --evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
    --gold_data_path output/biencoder-nq-dev.pages \ # a dataset containing ground truth answers for samples from the evaluation_set
    --predictions_path output/retrieval_preds.tsv \ # name of the file where predictions will be stored
    --eval_mode retrieval \ # indicates whether we're performing retrieval evaluation or e2e evaluation
    --k 1 # parameter k for the precision@k metric
```

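The `--k` flag above feeds the precision@k metric. As a sketch (not the script's actual implementation), precision@k over retrieved page titles amounts to:

```python
def precision_at_k(retrieved, gold, k=1):
    """Fraction of the top-k retrieved page titles that are gold pages."""
    gold_set = set(gold)
    top_k = retrieved[:k]
    if not top_k:
        return 0.0
    return sum(1 for title in top_k if title in gold_set) / k
```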
## End-to-end evaluation
We support two formats of the gold data file (controlled by the `gold_data_mode` parameter):
@ -97,7 +112,9 @@ who is the owner of reading football club ['Xiu Li Dai', 'Dai Yongge', 'Dai Xiul
Xiu Li Dai
```
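Judging from the sample above, a `qa` line pairs a question with a stringified list of answers, while an `ans` line holds a single answer. A minimal reader for the two formats might look like this (a hypothetical helper, assuming tab separation in `qa` mode, not code from the script itself):

```python
import ast

def read_gold_line(line, gold_data_mode):
    # "qa": question and a Python-list literal of answers, tab-separated
    # "ans": a bare answer string
    if gold_data_mode == "qa":
        question, answers = line.rstrip("\n").split("\t")
        return question, ast.literal_eval(answers)
    return None, [line.rstrip("\n")]
```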
Predictions of the model for the samples from the `evaluation_set` will be saved under the path specified by the `predictions_path` parameter.
If this path already exists, the script will use saved predictions to calculate metrics.
Add the `--recalculate` parameter to force the script to perform inference from scratch.
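The caching behaviour can be sketched as follows (hypothetical helper, not the script's code; `infer_fn` stands in for the actual model call):

```python
import os

def get_predictions(predictions_path, questions, infer_fn, recalculate=False):
    # Reuse saved predictions unless recalculation is forced.
    if os.path.exists(predictions_path) and not recalculate:
        with open(predictions_path) as f:
            return [line.rstrip("\n") for line in f]
    preds = [infer_fn(q) for q in questions]
    with open(predictions_path, "w") as f:
        f.writelines(p + "\n" for p in preds)
    return preds
```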
An example e2e evaluation run could look as follows:
```bash
# NOTE: illustrative flags and paths; `--gold_data_path` must point to a file in
# the chosen `--gold_data_mode` format (e.g. qa), not the retrieval .pages file
python examples/rag/eval_rag.py \
    --model_name_or_path facebook/rag-sequence-nq \
    --model_type rag_sequence \
    --evaluation_set output/biencoder-nq-dev.questions \
    --gold_data_path path/to/gold_data \
    --predictions_path output/e2e_preds.txt \
    --eval_mode e2e \
    --gold_data_mode qa
```
@ -0,0 +1,5 @@
import os
import sys


sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))
@ -15,7 +15,7 @@ from transformers import logging as transformers_logging
from utils import exact_match_score, f1_score # noqa: E402 # isort:skip


logger = logging.getLogger(__name__)
@ -31,16 +31,13 @@ from transformers import (
from transformers import logging as transformers_logging


from callbacks import ( # noqa: E402 # isort:skip
    get_checkpoint_callback,
    get_early_stopping_callback,
    Seq2SeqLoggingCallback,
)
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
from utils import ( # noqa: E402 # isort:skip
    calculate_exact_match,
    flatten_list,
    get_git_info,
@ -53,6 +50,11 @@ from examples.rag.utils import ( # noqa: E402 # isort:skip
    Seq2SeqDataset,
)

# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@ -23,7 +23,7 @@ from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FI

from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip


def require_distributed_retrieval(test_case):