migrate to fastembed

This commit is contained in:
generall 2023-10-17 00:06:08 +02:00
parent fd3ab3d3aa
commit d3d4efdab7
8 changed files with 573 additions and 1309 deletions

View File

@@ -36,7 +36,7 @@ COPY ./poetry.lock /app
COPY --from=build-step /app/dist /app/static
RUN poetry install --no-interaction --no-ansi --no-root --without dev
RUN python -c 'from sentence_transformers import SentenceTransformer; SentenceTransformer("all-MiniLM-L6-v2") '
RUN python -c 'from fastembed.embedding import DefaultEmbedding; DefaultEmbedding("BAAI/bge-small-en")'
# Finally copy the application source code and install root
COPY qdrant_demo /app/qdrant_demo

View File

@@ -128,7 +128,7 @@ export function Main() {
name={item.name}
images={item.logo_url}
alt={item.name}
description={item.short_description}
description={item.document}
link={item.homepage_url}
city={
item.city ??

1746
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -8,19 +8,12 @@ authors = ["Andrey Vasnetsov <andrey@vasnetsov.com>"]
python = "~3.11"
fastapi = "^0.103.1"
uvicorn = "^0.18.3"
sentence-transformers = "^2.2.0"
psutil = "^5.7.3"
nltk = "^3.7"
pandas = "^1.1.5"
loguru = "^0.5.3"
requests = "^2.25.1"
qdrant-client = "^1.5.4"
torch = [
{url="https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.0.1%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl", markers = "python_version == '3.11' and sys_platform == 'linux' and platform_machine == 'x86_64'"},
{url="https://download.pytorch.org/whl/cpu/torch-2.0.1-cp311-none-macosx_10_9_x86_64.whl", markers = "python_version == '3.11' and sys_platform == 'darwin' and platform_machine == 'x86_64'"},
{url="https://download.pytorch.org/whl/cpu/torch-2.0.1-cp311-none-macosx_11_0_arm64.whl", markers = "python_version == '3.11' and sys_platform == 'darwin' and platform_machine == 'arm64'"},
]
tqdm = "^4.66.1"
qdrant-client = {extras = ["fastembed"], version = "^1.6.3"}
[tool.poetry.dev-dependencies]

View File

@@ -9,5 +9,7 @@ QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333/")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "text-demo")
EMBEDDINGS_MODEL = os.environ.get("EMBEDDINGS_MODEL", "BAAI/bge-small-en")
VECTOR_FIELD_NAME = "fast-bge-small-en"
TEXT_FIELD_NAME = "short_description"

View File

@@ -1,55 +1,13 @@
import os.path
import numpy as np
import pandas as pd
from qdrant_client import QdrantClient, models
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, VECTOR_FIELD_NAME
# Define the CSV file path and NPY file path
csv_file_path = os.path.join(DATA_DIR, "organizations.csv")
npy_file_path = os.path.join(DATA_DIR, "embeddings.npy")
def generate_embeddings():
# Load the SentenceTransformer model
model = SentenceTransformer('all-MiniLM-L6-v2')
# Define a function to calculate embeddings
def calculate_embeddings(texts):
embeddings = model.encode(texts, show_progress_bar=False)
return embeddings
# Load the CSV file into a DataFrame
df = pd.read_csv(csv_file_path)
# Handle missing or non-string values in the TEXT_FIELD_NAME column
df[TEXT_FIELD_NAME] = df[TEXT_FIELD_NAME].fillna('') # Replace NaN with empty string
df[TEXT_FIELD_NAME] = df[TEXT_FIELD_NAME].astype(str) # Ensure all values are strings
# Split the data into chunks to save RAM
batch_size = 1000
num_chunks = len(df) // batch_size + 1
embeddings_list = []
# Iterate over chunks and calculate embeddings
for i in tqdm(range(num_chunks), desc="Calculating Embeddings"):
start_idx = i * batch_size
end_idx = (i + 1) * batch_size
batch_texts = df[TEXT_FIELD_NAME].iloc[start_idx:end_idx].tolist()
batch_embeddings = calculate_embeddings(batch_texts)
embeddings_list.extend(batch_embeddings)
# Convert embeddings list to a numpy array
embeddings_array = np.array(embeddings_list)
# Save the embeddings to an NPY file
np.save(npy_file_path, embeddings_array)
print(f"Embeddings saved to {npy_file_path}")
def upload_embeddings():
@@ -58,21 +16,14 @@ def upload_embeddings():
api_key=QDRANT_API_KEY,
)
df = pd.read_csv(csv_file_path)
payload = df.to_dict('records')
vectors = np.load(npy_file_path)
df = pd.read_csv(csv_file_path, nrows=1000)
documents = df['short_description'].tolist()
df.drop(columns=['short_description'], inplace=True)
metadata = df.to_dict('records')
client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config={
VECTOR_FIELD_NAME: models.VectorParams(
size=vectors.shape[1],
distance=models.Distance.COSINE,
on_disk=True,
)
},
vectors_config=client.get_fastembed_vector_params(on_disk=True),
# Quantization is optional, but it can significantly reduce the memory usage
quantization_config=models.ScalarQuantization(
scalar=models.ScalarQuantizationConfig(
@@ -97,17 +48,14 @@ def upload_embeddings():
)
)
client.upload_collection(
client.add(
collection_name=COLLECTION_NAME,
vectors={
VECTOR_FIELD_NAME: vectors
},
payload=payload,
ids=None, # Vector ids will be assigned automatically
batch_size=256 # How many vectors will be uploaded in a single request?
documents=documents,
metadata=metadata,
ids=tqdm(range(len(documents))),
parallel=0,
)
if __name__ == '__main__':
generate_embeddings()
upload_embeddings()

View File

@@ -1,50 +1,37 @@
import json
import os.path
import numpy as np
from qdrant_client import QdrantClient, models
from qdrant_client.qdrant_fastembed import SUPPORTED_EMBEDDING_MODELS
from tqdm import tqdm
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, VECTOR_FIELD_NAME
# Define the CSV file path and NPY file path
csv_file_path = os.path.join(DATA_DIR, "organizations.csv")
npy_file_path = os.path.join(DATA_DIR, "embeddings.npy")
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, \
VECTOR_FIELD_NAME, EMBEDDINGS_MODEL
def upload_embeddings():
client = QdrantClient(
url=QDRANT_URL,
api_key=QDRANT_API_KEY,
prefer_grpc=True,
)
vectors_path = os.path.join(DATA_DIR, 'startup_vectors.npy')
vectors = np.load(vectors_path)
vector_size = vectors.shape[1]
payload_path = os.path.join(DATA_DIR, 'startups.json')
payload = []
documents = []
with open(payload_path) as fd:
for line in fd:
obj = json.loads(line)
# Rename fields to unified schema
obj[TEXT_FIELD_NAME] = obj.pop('description')
documents.append(obj.pop('description'))
obj["logo_url"] = obj.pop("images")
obj["homepage_url"] = obj.pop("link")
payload.append(obj)
client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config={
VECTOR_FIELD_NAME: models.VectorParams(
size=vector_size,
distance=models.Distance.COSINE,
on_disk=True,
)
},
vectors_config=client.get_fastembed_vector_params(on_disk=True),
# Quantization is optional, but it can significantly reduce the memory usage
quantization_config=models.ScalarQuantization(
scalar=models.ScalarQuantizationConfig(
@@ -69,15 +56,12 @@ def upload_embeddings():
)
)
client.upload_collection(
client.add(
collection_name=COLLECTION_NAME,
vectors={
VECTOR_FIELD_NAME: vectors
},
payload=payload,
ids=None, # Vector ids will be assigned automatically
batch_size=64, # How many vectors will be uploaded in a single request?
parallel=10,
documents=documents,
metadata=payload,
ids=tqdm(range(len(payload))),
parallel=0,
)

View File

@@ -2,24 +2,21 @@ from typing import List
from qdrant_client import QdrantClient
from qdrant_client.http.models.models import Filter
from sentence_transformers import SentenceTransformer
from qdrant_demo.config import QDRANT_URL, QDRANT_API_KEY, VECTOR_FIELD_NAME
from qdrant_demo.config import QDRANT_URL, QDRANT_API_KEY
class NeuralSearcher:
def __init__(self, collection_name: str):
self.collection_name = collection_name
self.model = SentenceTransformer('all-MiniLM-L6-v2')
self.qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
def search(self, text: str, filter_: dict = None) -> List[dict]:
vector = self.model.encode(text).tolist()
hits = self.qdrant_client.search(
hits = self.qdrant_client.query(
collection_name=self.collection_name,
query_vector=(VECTOR_FIELD_NAME, vector),
query_text=text,
query_filter=Filter(**filter_) if filter_ else None,
limit=5
)
return [hit.payload for hit in hits]
return [hit.metadata for hit in hits]