81 lines
3.1 KiB
Python
81 lines
3.1 KiB
Python
import os
|
|
from functools import partial
|
|
from glob import glob
|
|
|
|
import faiss
|
|
from datasets import Features, Sequence, Value, concatenate_datasets, load_dataset, load_from_disk
|
|
|
|
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast
|
|
|
|
|
|
def split_text(text, n=100, character=" "):
    """Break ``text`` into chunks of at most ``n`` ``character``-separated pieces.

    Each chunk is re-joined with ``character`` and stripped of surrounding
    whitespace. An empty ``text`` yields a single empty chunk.
    """
    pieces = text.split(character)
    chunks = []
    for start in range(0, len(pieces), n):
        window = pieces[start : start + n]
        chunks.append(character.join(window).strip())
    return chunks
|
|
|
|
|
|
def split_documents(documents):
    """Expand a batch of documents into one row per passage.

    ``documents`` is a batch with parallel ``"title"`` and ``"text"`` columns.
    Rows whose text is ``None`` are dropped; a ``None`` title becomes ``""``.
    Every passage produced by ``split_text`` keeps the title of its document.
    """
    passage_titles = []
    passage_texts = []
    for doc_title, doc_text in zip(documents["title"], documents["text"]):
        if doc_text is None:
            continue
        passages = split_text(doc_text)
        safe_title = "" if doc_title is None else doc_title
        passage_titles.extend([safe_title] * len(passages))
        passage_texts.extend(passages)
    return {"title": passage_titles, "text": passage_texts}
|
|
|
|
|
|
def embed_update(ctx_encoder, total_processes, device, process_num, shard_dir, csv_path):
    """Embed one process's share of the knowledge base and save it to disk.

    The tab-separated file at ``csv_path`` is loaded with the columns
    ``title`` and ``text``, split into passages with ``split_documents`` and
    divided into ``total_processes`` contiguous shards. Only the shard at
    index ``process_num`` is embedded; the result is written to
    ``<shard_dir>/data_<process_num>``.

    Args:
        ctx_encoder: DPR context encoder used to compute the embeddings. It is
            moved to ``device`` before use.
        total_processes: Number of contiguous shards the knowledge base is
            divided into.
        device: Device the encoder and the token ids are moved to.
        process_num: Index of the shard handled by this call; expected to lie
            in ``range(total_processes)``.
        shard_dir: Directory under which the shard folder is saved.
        csv_path: Path to the tab-separated knowledge-base file.
    """
    # Function-scope import: torch is only needed here, to switch autograd off.
    import torch

    kb_dataset = load_dataset(
        "csv", data_files=[csv_path], split="train", delimiter="\t", column_names=["title", "text"]
    )
    # This step can be skipped if the CSV already contains pre-split passages.
    kb_dataset = kb_dataset.map(split_documents, batched=True, num_proc=1)

    # Request only this process's shard instead of building every shard and
    # discarding all but one.
    data_shard = kb_dataset.shard(total_processes, process_num, contiguous=True)

    arrow_folder = "data_" + str(process_num)
    passages_path = os.path.join(shard_dir, arrow_folder)

    context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
    ctx_encoder = ctx_encoder.to(device=device)

    def embed(
        documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast, device
    ) -> dict:
        """Compute the DPR embeddings of document passages"""
        input_ids = ctx_tokenizer(
            documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
        )["input_ids"]
        # The embeddings are only exported as numpy arrays, so no autograd
        # graph is needed; disabling it avoids holding activations in memory.
        with torch.no_grad():
            embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
        return {"embeddings": embeddings.detach().cpu().numpy()}

    # Optional: store the embeddings as float32 instead of float64 to save space.
    new_features = Features(
        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
    )

    dataset = data_shard.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=context_tokenizer, device=device),
        batched=True,
        batch_size=16,
        features=new_features,
    )
    dataset.save_to_disk(passages_path)
|
|
|
|
|
|
def _shard_sort_key(shard_address):
    """Order shard folders by their trailing ``_<number>``, then by name."""
    folder = os.path.basename(os.path.normpath(shard_address))
    suffix = folder.rpartition("_")[2]
    if suffix.isdecimal():
        return (0, int(suffix), folder)
    # Folders without a numeric suffix sort after the numbered ones, by name.
    return (1, 0, folder)


def add_index(shard_dir, index_path):
    """Build a FAISS HNSW index over all saved shards and write it to disk.

    Every sub-directory of ``shard_dir`` is loaded as a dataset, the datasets
    are concatenated, and an inner-product ``IndexHNSWFlat`` is built on the
    ``"embeddings"`` column and saved to ``index_path``.

    Args:
        shard_dir: Directory whose immediate sub-directories each hold one
            dataset saved with ``save_to_disk``.
        index_path: File the FAISS index is written to.
    """
    # glob() gives no ordering guarantee, and a plain string sort would put
    # "data_10" before "data_2". Sorting on the numeric suffix keeps the
    # concatenated rows (and therefore the index ids) in shard order on every
    # run and platform.
    shard_addresses = sorted(glob(str(shard_dir) + "/*/"), key=_shard_sort_key)
    data_shard_list = [load_from_disk(shard_address) for shard_address in shard_addresses]

    concat = concatenate_datasets(data_shard_list)
    faiss.omp_set_num_threads(96)

    # Arguments are passed positionally: 768 and 128 are the first two
    # IndexHNSWFlat parameters, followed by the inner-product metric.
    index = faiss.IndexHNSWFlat(768, 128, faiss.METRIC_INNER_PRODUCT)
    concat.add_faiss_index("embeddings", custom_index=index)
    # Since the index is loaded into memory when it is used, the copy on disk
    # can be overwritten directly.
    concat.get_index("embeddings").save(index_path)
|