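"""Build a custom knowledge dataset for RAG: load (title, text) documents from a
tab-separated csv file, split them into 100-word passages, embed the passages with a
DPR context encoder, and index the embeddings with a Faiss HNSW index."""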
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional

import faiss
import torch
from datasets import Features, Sequence, Value, load_dataset

from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast, HfArgumentParser


logger = logging.getLogger(__name__)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split the text every ``n``-th occurrence of ``character``"""
    text = text.split(character)
    return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
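
# For reference, split_text splits on the given character (space by default) and joins
# back groups of n tokens, e.g.:
#   split_text("a b c d e", n=2) -> ["a b", "c d", "e"]
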
def split_documents(documents: dict) -> dict:
    """Split documents into passages"""
    titles, texts = [], []
    for title, text in zip(documents["title"], documents["text"]):
        if text is not None:
            for passage in split_text(text):
                titles.append(title if title is not None else "")
                texts.append(passage)
    return {"title": titles, "text": texts}
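
# Like `split_documents` above, `embed` below is written for the `datasets.map(batched=True)`
# contract: it receives a batch as a dict of columns ("title", "text") and returns the new
# "embeddings" column for that batch.
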
def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of document passages"""
    input_ids = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )["input_ids"]
    embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
    return {"embeddings": embeddings.detach().cpu().numpy()}
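
# The DPR pooler output is one vector per passage; for the base DPR checkpoints its
# dimension is 768, which is why `IndexHnswArguments.d` defaults to 768 below.
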
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # The dataset needed for RAG must have three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage

    # Let's say you have documents in tab-separated csv files with columns "title" and "text"
    assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"

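    # Each row of the file is expected to hold one document, formatted as: <title>\t<text>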
    # You can load a Dataset object this way
    dataset = load_dataset(
        "csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
    )

    # More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets?highlight=csv#csv-files

    # Then split the documents into passages of 100 words
    dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)

    # And compute the embeddings
    ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
    new_features = Features(
        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
    )  # optional, save as float32 instead of float64 to save space
    dataset = dataset.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
        features=new_features,
    )

    # And finally save your dataset
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)
    # from datasets import load_from_disk
    # dataset = load_from_disk(passages_path)  # to reload the dataset

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=index)
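    # METRIC_INNER_PRODUCT matches DPR, which scores passages by the dot product of question
    # and passage embeddings; `d` must equal the dimension of those embeddings (768 by default).
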
    # And save the index
    index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    dataset.get_index("embeddings").save(index_path)
    # dataset.load_faiss_index("embeddings", index_path)  # to reload the index
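
    # A sketch of how the saved artifacts can be consumed downstream (assuming the
    # `transformers` RAG retriever API):
    #   from transformers import RagRetriever
    #   retriever = RagRetriever.from_pretrained(
    #       rag_example_args.rag_model_name, index_name="custom",
    #       passages_path=passages_path, index_path=index_path,
    #   )
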
@dataclass
class RagExampleArguments:
    csv_path: str = field(
        default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset.csv"),
        metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
    )
    question: Optional[str] = field(
        default=None,
        metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
    )
    rag_model_name: str = field(
        default="facebook/rag-sequence-nq",
        metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
    )
    dpr_ctx_encoder_model_name: str = field(
        default="facebook/dpr-ctx_encoder-multiset-base",
        metadata={
            "help": (
                "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or"
                " 'facebook/dpr-ctx_encoder-multiset-base'"
            )
        },
    )
    output_dir: Optional[str] = field(
        default=str(Path(__file__).parent / "test_run" / "dummy-kb"),
        metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
    )

@dataclass
class ProcessingArguments:
    num_proc: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use to split the documents into passages. Default is single process."
        },
    )
    batch_size: int = field(
        default=16,
        metadata={
            "help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
        },
    )

@dataclass
class IndexHnswArguments:
    d: int = field(
        default=768,
        metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
    )
    m: int = field(
        default=128,
        metadata={
            "help": (
                "The number of bi-directional links created for every new element during the HNSW index construction."
            )
        },
    )

if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    logger.setLevel(logging.INFO)

    parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
    rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()
    with TemporaryDirectory() as tmp_dir:
        rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir
        main(rag_example_args, processing_args, index_hnsw_args)
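
# Example invocation (assuming this script is saved as use_own_knowledge_dataset.py; the
# --csv_path and --output_dir flags are generated from the dataclass fields above by
# HfArgumentParser):
#
#   python use_own_knowledge_dataset.py \
#       --csv_path path/to/my_knowledge_dataset.csv \
#       --output_dir path/to/my_knowledge_dataset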