321 lines
11 KiB
Python
321 lines
11 KiB
Python
""" Evaluation script for RAG models."""
|
|
|
|
import argparse
|
|
import ast
|
|
import logging
|
|
import os
|
|
import sys
|
|
|
|
import pandas as pd
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from transformers import BartForConditionalGeneration, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration
|
|
from transformers import logging as transformers_logging
|
|
|
|
|
|
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
|
|
from utils_rag import exact_match_score, f1_score # noqa: E402 # isort:skip
|
|
|
|
|
|
# Module-level logging setup: INFO-level root logging configuration, this
# script's own logger, and INFO verbosity for the transformers library.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
transformers_logging.set_verbosity_info()
|
|
|
|
|
|
def infer_model_type(model_name_or_path):
    """Guess the model family from a checkpoint name or path.

    Substrings are tested in priority order: "token", then "sequence", then
    "bart". The first one contained in ``model_name_or_path`` decides.

    Returns:
        "rag_token", "rag_sequence", "bart", or None when nothing matches.
    """
    substring_to_type = (
        ("token", "rag_token"),
        ("sequence", "rag_sequence"),
        ("bart", "bart"),
    )
    return next(
        (model_type for needle, model_type in substring_to_type if needle in model_name_or_path),
        None,
    )
|
|
|
|
|
|
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score ``prediction`` against every reference and keep the best score.

    Args:
        metric_fn: callable ``(prediction, ground_truth) -> score``.
        prediction: the predicted answer.
        ground_truths: iterable of acceptable reference answers; an empty
            iterable makes ``max`` raise ValueError.

    Returns:
        The highest score ``metric_fn`` produced over ``ground_truths``.
    """
    per_reference_scores = [metric_fn(prediction, reference) for reference in ground_truths]
    return max(per_reference_scores)
|
|
|
|
|
|
def get_scores(args, preds_path, gold_data_path):
    """Compute and log exact-match (EM) and F1 of predictions against gold answers.

    Args:
        args: parsed CLI namespace; only ``args.gold_data_mode`` is read.
            In "qa" mode ``gold_data_path`` is a header-less TSV whose second
            column holds a Python-literal list of acceptable answers.
            Otherwise each line of ``gold_data_path`` is the single expected answer.
        preds_path: text file with one predicted answer per line.
        gold_data_path: path to the gold data described above.

    Returns:
        dict with keys "em" and "f1" (percentages), or None when there is
        nothing to score.
    """
    with open(preds_path, "r") as preds_file:
        hypos = [line.strip() for line in preds_file]

    if args.gold_data_mode == "qa":
        data = pd.read_csv(gold_data_path, sep="\t", header=None)
        # literal_eval (not eval) parses the stringified answer list safely.
        answers = [ast.literal_eval(answer_list) for answer_list in data[1]]
    else:
        with open(gold_data_path, "r") as gold_file:
            answers = [[line.strip()] for line in gold_file]

    if len(hypos) != len(answers):
        # zip() below truncates to the shorter list; make that visible.
        logger.warning(
            "Number of predictions (%d) does not match number of gold samples (%d); scoring the first %d",
            len(hypos),
            len(answers),
            min(len(hypos), len(answers)),
        )

    total = min(len(hypos), len(answers))
    if total == 0:
        logger.warning("No predictions to score in %s", preds_path)
        return None

    em = f1 = 0
    for prediction, ground_truths in zip(hypos, answers):
        em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

    em = 100.0 * em / total
    f1 = 100.0 * f1 / total

    logger.info(f"F1: {f1:.2f}")
    logger.info(f"EM: {em:.2f}")
    return {"em": em, "f1": f1}
|
|
|
|
|
|
def get_precision_at_k(args, preds_path, gold_data_path):
    """Compute and log precision@k of retrieved provenance titles.

    Each line of ``preds_path`` and ``gold_data_path`` is a tab-separated list
    of document titles. For every example, the first ``args.k`` predicted
    titles are deduplicated and intersected with the gold titles. The overlap
    size divided by ``k`` is that example's precision.

    Returns:
        Precision@k as a percentage, or None when there is nothing to score.

    Raises:
        ValueError: if ``args.k`` is smaller than 1.
    """
    k = args.k
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    with open(preds_path, "r") as preds_file:
        hypos = [line.strip() for line in preds_file]
    with open(gold_data_path, "r") as gold_file:
        references = [line.strip() for line in gold_file]

    if len(hypos) != len(references):
        # zip() below truncates to the shorter list; make that visible.
        logger.warning(
            "Number of predictions (%d) does not match number of gold samples (%d); scoring the first %d",
            len(hypos),
            len(references),
            min(len(hypos), len(references)),
        )

    total = min(len(hypos), len(references))
    if total == 0:
        logger.warning("No predictions to score in %s", preds_path)
        return None

    precision_sum = 0
    for hypo, reference in zip(hypos, references):
        hypo_provenance = set(hypo.split("\t")[:k])
        ref_provenance = set(reference.split("\t"))
        precision_sum += len(hypo_provenance & ref_provenance) / k

    precision = 100.0 * precision_sum / total
    logger.info(f"Precision@{k}: {precision:.2f}")
    return precision
|
|
|
|
|
|
def evaluate_batch_retrieval(args, rag_model, questions):
    """Retrieve documents for a batch of questions and return their titles.

    Args:
        args: parsed CLI namespace; ``args.device`` is where the tokenized
            questions are moved.
        rag_model: a RAG model whose ``retriever`` attribute is attached.
        questions: list of question strings.

    Returns:
        One string per question: the titles of the ``rag_model.config.n_docs``
        retrieved documents, each with at most one leading and one trailing
        double quote removed, joined by tabs.
    """

    def strip_title(title):
        # Remove a single surrounding double quote on each side, if present.
        if title.startswith('"'):
            title = title[1:]
        if title.endswith('"'):
            title = title[:-1]
        return title

    # Inference only. As in evaluate_batch_e2e, don't record an autograd
    # graph for the question-encoder forward pass.
    with torch.no_grad():
        retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
            questions,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )["input_ids"].to(args.device)

        question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids)
        question_enc_pool_output = question_enc_outputs[0]

        result = rag_model.retriever(
            retriever_input_ids,
            question_enc_pool_output.cpu().detach().to(torch.float32).numpy(),
            prefix=rag_model.rag.generator.config.prefix,
            n_docs=rag_model.config.n_docs,
            return_tensors="pt",
        )

    all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
    return ["\t".join(strip_title(title) for title in docs["title"]) for docs in all_docs]
|
|
|
|
|
|
def evaluate_batch_e2e(args, rag_model, questions):
    """Generate one answer per question with ``rag_model``.

    Questions are tokenized with the retriever's question-encoder tokenizer,
    answers are generated with the beam/length settings from ``args``, and
    decoded with the retriever's generator tokenizer. When
    ``args.print_predictions`` is set, each question/answer pair is logged.

    Returns:
        list of decoded answer strings, aligned with ``questions``.
    """
    retriever = rag_model.retriever
    with torch.no_grad():
        encoded = retriever.question_encoder_tokenizer.batch_encode_plus(
            questions, return_tensors="pt", padding=True, truncation=True
        )
        generation_kwargs = dict(
            attention_mask=encoded.attention_mask.to(args.device),
            num_beams=args.num_beams,
            min_length=args.min_length,
            max_length=args.max_length,
            early_stopping=False,
            num_return_sequences=1,
            # BART likes to repeat BOS tokens; banning the [0, 0] sequence keeps
            # it from generating more than one.
            bad_words_ids=[[0, 0]],
        )
        # RAG models override generate().
        generated = rag_model.generate(encoded.input_ids.to(args.device), **generation_kwargs)
        answers = retriever.generator_tokenizer.batch_decode(generated, skip_special_tokens=True)

    if args.print_predictions:
        for question, answer in zip(questions, answers):
            logger.info("Q: {} - A: {}".format(question, answer))

    return answers
|
|
|
|
|
|
def get_args():
    """Parse the command-line options for the evaluation script.

    Returns:
        argparse.Namespace with all options, plus ``device``: a torch.device
        that is "cuda" when CUDA is available and "cpu" otherwise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        choices=["rag_sequence", "rag_token", "bart"],
        type=str,
        help=(
            "RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the"
            " model_name_or_path"
        ),
    )
    parser.add_argument(
        "--index_name",
        default=None,
        choices=["exact", "compressed", "legacy"],
        type=str,
        help="RAG model retriever type",
    )
    parser.add_argument(
        "--index_path",
        default=None,
        type=str,
        help="Path to the retrieval index",
    )
    parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs")
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained checkpoints or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--eval_mode",
        choices=["e2e", "retrieval"],
        default="e2e",
        type=str,
        help=(
            "Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates"
            " precision@k."
        ),
    )
    parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
    parser.add_argument(
        "--evaluation_set",
        default=None,
        type=str,
        required=True,
        help="Path to a file containing evaluation samples",
    )
    parser.add_argument(
        "--gold_data_path",
        default=None,
        type=str,
        required=True,
        help="Path to a tab-separated file with gold samples",
    )
    parser.add_argument(
        "--gold_data_mode",
        default="qa",
        type=str,
        choices=["qa", "ans"],
        help=(
            "Format of the gold data file. "
            "qa - each line has the format: question [tab] answer_list; "
            "ans - each line of the gold file contains the expected answer string"
        ),
    )
    parser.add_argument(
        "--predictions_path",
        type=str,
        default="predictions.txt",
        help="Name of the predictions file, to be stored in the checkpoints directory",
    )
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate every subdirectory of model_name_or_path as a separate checkpoint",
    )
    parser.add_argument(
        "--eval_batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument(
        "--recalculate",
        help="Recalculate predictions even if the prediction file exists",
        action="store_true",
    )
    parser.add_argument(
        "--num_beams",
        default=4,
        type=int,
        help="Number of beams to be used when generating answers",
    )
    parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers")
    parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers")

    parser.add_argument(
        "--print_predictions",
        action="store_true",
        help="If True, prints predictions while evaluating.",
    )
    # NOTE(review): args.print_docs is not read by any code visible in this
    # file; confirm whether it is still needed.
    parser.add_argument(
        "--print_docs",
        action="store_true",
        help="If True, prints docs retrieved while generating.",
    )
    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return args
|
|
|
|
|
|
def _resolve_model_class(args):
    """Return ``(model_class, from_pretrained_kwargs)`` for ``args.model_type``.

    When ``args.model_type`` is unset, it is inferred from
    ``args.model_name_or_path`` and written back onto ``args``.

    Raises:
        ValueError: if the model type cannot be inferred.
    """
    if args.model_type is None:
        args.model_type = infer_model_type(args.model_name_or_path)
        if args.model_type is None:
            raise ValueError(
                f"Could not infer model type from {args.model_name_or_path!r}; pass --model_type explicitly"
            )

    model_kwargs = {}
    if args.model_type.startswith("rag"):
        model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
        model_kwargs["n_docs"] = args.n_docs
        if args.index_name is not None:
            model_kwargs["index_name"] = args.index_name
        if args.index_path is not None:
            model_kwargs["index_path"] = args.index_path
    else:
        model_class = BartForConditionalGeneration
    return model_class, model_kwargs


def _predictions_path_for(args, checkpoint):
    """Return the predictions file used for ``checkpoint``.

    With --eval_all_checkpoints each checkpoint gets its own file inside its
    directory. Otherwise all checkpoints would share one file, and every
    checkpoint after the first would be skipped as "already predicted".
    A single checkpoint keeps using ``args.predictions_path`` as given,
    because model_name_or_path may be a hub identifier rather than a directory.
    """
    if args.eval_all_checkpoints:
        return os.path.join(checkpoint, os.path.basename(args.predictions_path))
    return args.predictions_path


def _load_model(args, checkpoint, model_class, model_kwargs):
    """Load ``checkpoint`` (with a retriever for RAG models) and move it to ``args.device``."""
    if args.model_type.startswith("rag"):
        retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
        model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
        model.retriever.init_retrieval()
    else:
        model = model_class.from_pretrained(checkpoint, **model_kwargs)
    model.to(args.device)
    return model


def _write_predictions(args, model, evaluate_batch_fn, predictions_path):
    """Run ``evaluate_batch_fn`` over ``args.evaluation_set`` in batches, writing one line per question."""
    with open(args.evaluation_set, "r") as eval_file, open(predictions_path, "w") as preds_file:

        def flush(batch):
            answers = evaluate_batch_fn(args, model, batch)
            # Always terminate with "\n" so that an empty final answer is not
            # lost when the file is read back line by line.
            preds_file.write("\n".join(answers) + "\n")
            preds_file.flush()

        questions = []
        for line in tqdm(eval_file):
            questions.append(line.strip())
            if len(questions) == args.eval_batch_size:
                flush(questions)
                questions = []
        if questions:
            flush(questions)


def main(args):
    """Generate predictions for one or more checkpoints and score them.

    Each checkpoint's predictions are written one per line and then scored:
    EM/F1 for ``--eval_mode e2e``, precision@k for ``--eval_mode retrieval``.
    If a checkpoint's predictions file already exists and ``--recalculate`` is
    not set, generation is skipped and the existing file is scored instead.

    Raises:
        ValueError: if ``--model_type`` is not given and cannot be inferred
            from ``model_name_or_path``.
    """
    model_class, model_kwargs = _resolve_model_class(args)

    if args.eval_all_checkpoints:
        # os.scandir yields entries in arbitrary order; sort for a reproducible run.
        with os.scandir(args.model_name_or_path) as entries:
            checkpoints = sorted(entry.path for entry in entries if entry.is_dir())
    else:
        checkpoints = [args.model_name_or_path]

    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
    evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval

    for checkpoint in checkpoints:
        predictions_path = _predictions_path_for(args, checkpoint)
        if os.path.exists(predictions_path) and (not args.recalculate):
            logger.info("Calculating metrics based on an existing predictions file: {}".format(predictions_path))
            score_fn(args, predictions_path, args.gold_data_path)
            continue

        logger.info("***** Running evaluation for {} *****".format(checkpoint))
        logger.info(" Batch size = %d", args.eval_batch_size)
        logger.info(" Predictions will be stored under {}".format(predictions_path))

        model = _load_model(args, checkpoint, model_class, model_kwargs)
        _write_predictions(args, model, evaluate_batch_fn, predictions_path)
        # Drop the reference so this model can be freed before the next checkpoint loads.
        del model

        score_fn(args, predictions_path, args.gold_data_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the command line, then run the evaluation with it.
    main(get_args())
|