migrate to fastembed
This commit is contained in:
parent
fd3ab3d3aa
commit
d3d4efdab7
|
@ -36,7 +36,7 @@ COPY ./poetry.lock /app
|
||||||
COPY --from=build-step /app/dist /app/static
|
COPY --from=build-step /app/dist /app/static
|
||||||
|
|
||||||
RUN poetry install --no-interaction --no-ansi --no-root --without dev
|
RUN poetry install --no-interaction --no-ansi --no-root --without dev
|
||||||
RUN python -c 'from sentence_transformers import SentenceTransformer; SentenceTransformer("all-MiniLM-L6-v2") '
|
RUN python -c 'from fastembed.embedding import DefaultEmbedding; DefaultEmbedding("BAAI/bge-small-en")'
|
||||||
|
|
||||||
# Finally copy the application source code and install root
|
# Finally copy the application source code and install root
|
||||||
COPY qdrant_demo /app/qdrant_demo
|
COPY qdrant_demo /app/qdrant_demo
|
||||||
|
|
|
@ -128,7 +128,7 @@ export function Main() {
|
||||||
name={item.name}
|
name={item.name}
|
||||||
images={item.logo_url}
|
images={item.logo_url}
|
||||||
alt={item.name}
|
alt={item.name}
|
||||||
description={item.short_description}
|
description={item.document}
|
||||||
link={item.homepage_url}
|
link={item.homepage_url}
|
||||||
city={
|
city={
|
||||||
item.city ??
|
item.city ??
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -8,19 +8,12 @@ authors = ["Andrey Vasnetsov <andrey@vasnetsov.com>"]
|
||||||
python = "~3.11"
|
python = "~3.11"
|
||||||
fastapi = "^0.103.1"
|
fastapi = "^0.103.1"
|
||||||
uvicorn = "^0.18.3"
|
uvicorn = "^0.18.3"
|
||||||
sentence-transformers = "^2.2.0"
|
|
||||||
psutil = "^5.7.3"
|
psutil = "^5.7.3"
|
||||||
nltk = "^3.7"
|
|
||||||
pandas = "^1.1.5"
|
pandas = "^1.1.5"
|
||||||
loguru = "^0.5.3"
|
loguru = "^0.5.3"
|
||||||
requests = "^2.25.1"
|
requests = "^2.25.1"
|
||||||
qdrant-client = "^1.5.4"
|
|
||||||
torch = [
|
|
||||||
{url="https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.0.1%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl", markers = "python_version == '3.11' and sys_platform == 'linux' and platform_machine == 'x86_64'"},
|
|
||||||
{url="https://download.pytorch.org/whl/cpu/torch-2.0.1-cp311-none-macosx_10_9_x86_64.whl", markers = "python_version == '3.11' and sys_platform == 'darwin' and platform_machine == 'x86_64'"},
|
|
||||||
{url="https://download.pytorch.org/whl/cpu/torch-2.0.1-cp311-none-macosx_11_0_arm64.whl", markers = "python_version == '3.11' and sys_platform == 'darwin' and platform_machine == 'arm64'"},
|
|
||||||
]
|
|
||||||
tqdm = "^4.66.1"
|
tqdm = "^4.66.1"
|
||||||
|
qdrant-client = {extras = ["fastembed"], version = "^1.6.3"}
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
|
|
||||||
|
|
|
@ -9,5 +9,7 @@ QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333/")
|
||||||
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
|
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
|
||||||
|
|
||||||
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "text-demo")
|
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "text-demo")
|
||||||
|
EMBEDDINGS_MODEL = os.environ.get("EMBEDDINGS_MODEL", "BAAI/bge-small-en")
|
||||||
|
|
||||||
VECTOR_FIELD_NAME = "fast-bge-small-en"
|
VECTOR_FIELD_NAME = "fast-bge-small-en"
|
||||||
TEXT_FIELD_NAME = "short_description"
|
TEXT_FIELD_NAME = "short_description"
|
||||||
|
|
|
@ -1,55 +1,13 @@
|
||||||
import os.path
|
import os.path
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from qdrant_client import QdrantClient, models
|
from qdrant_client import QdrantClient, models
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
|
|
||||||
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, VECTOR_FIELD_NAME
|
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, VECTOR_FIELD_NAME
|
||||||
|
|
||||||
# Define the CSV file path and NPY file path
|
# Define the CSV file path and NPY file path
|
||||||
csv_file_path = os.path.join(DATA_DIR, "organizations.csv")
|
csv_file_path = os.path.join(DATA_DIR, "organizations.csv")
|
||||||
npy_file_path = os.path.join(DATA_DIR, "embeddings.npy")
|
|
||||||
|
|
||||||
|
|
||||||
def generate_embeddings():
|
|
||||||
# Load the SentenceTransformer model
|
|
||||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
|
||||||
|
|
||||||
# Define a function to calculate embeddings
|
|
||||||
def calculate_embeddings(texts):
|
|
||||||
embeddings = model.encode(texts, show_progress_bar=False)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
# Load the CSV file into a DataFrame
|
|
||||||
df = pd.read_csv(csv_file_path)
|
|
||||||
|
|
||||||
# Handle missing or non-string values in the TEXT_FIELD_NAME column
|
|
||||||
df[TEXT_FIELD_NAME] = df[TEXT_FIELD_NAME].fillna('') # Replace NaN with empty string
|
|
||||||
df[TEXT_FIELD_NAME] = df[TEXT_FIELD_NAME].astype(str) # Ensure all values are strings
|
|
||||||
|
|
||||||
# Split the data into chunks to save RAM
|
|
||||||
batch_size = 1000
|
|
||||||
num_chunks = len(df) // batch_size + 1
|
|
||||||
|
|
||||||
embeddings_list = []
|
|
||||||
|
|
||||||
# Iterate over chunks and calculate embeddings
|
|
||||||
for i in tqdm(range(num_chunks), desc="Calculating Embeddings"):
|
|
||||||
start_idx = i * batch_size
|
|
||||||
end_idx = (i + 1) * batch_size
|
|
||||||
batch_texts = df[TEXT_FIELD_NAME].iloc[start_idx:end_idx].tolist()
|
|
||||||
batch_embeddings = calculate_embeddings(batch_texts)
|
|
||||||
embeddings_list.extend(batch_embeddings)
|
|
||||||
|
|
||||||
# Convert embeddings list to a numpy array
|
|
||||||
embeddings_array = np.array(embeddings_list)
|
|
||||||
|
|
||||||
# Save the embeddings to an NPY file
|
|
||||||
np.save(npy_file_path, embeddings_array)
|
|
||||||
|
|
||||||
print(f"Embeddings saved to {npy_file_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def upload_embeddings():
|
def upload_embeddings():
|
||||||
|
@ -58,21 +16,14 @@ def upload_embeddings():
|
||||||
api_key=QDRANT_API_KEY,
|
api_key=QDRANT_API_KEY,
|
||||||
)
|
)
|
||||||
|
|
||||||
df = pd.read_csv(csv_file_path)
|
df = pd.read_csv(csv_file_path, nrows=1000)
|
||||||
|
documents = df['short_description'].tolist()
|
||||||
payload = df.to_dict('records')
|
df.drop(columns=['short_description'], inplace=True)
|
||||||
|
metadata = df.to_dict('records')
|
||||||
vectors = np.load(npy_file_path)
|
|
||||||
|
|
||||||
client.recreate_collection(
|
client.recreate_collection(
|
||||||
collection_name=COLLECTION_NAME,
|
collection_name=COLLECTION_NAME,
|
||||||
vectors_config={
|
vectors_config=client.get_fastembed_vector_params(on_disk=True),
|
||||||
VECTOR_FIELD_NAME: models.VectorParams(
|
|
||||||
size=vectors.shape[1],
|
|
||||||
distance=models.Distance.COSINE,
|
|
||||||
on_disk=True,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
# Quantization is optional, but it can significantly reduce the memory usage
|
# Quantization is optional, but it can significantly reduce the memory usage
|
||||||
quantization_config=models.ScalarQuantization(
|
quantization_config=models.ScalarQuantization(
|
||||||
scalar=models.ScalarQuantizationConfig(
|
scalar=models.ScalarQuantizationConfig(
|
||||||
|
@ -97,17 +48,14 @@ def upload_embeddings():
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
client.upload_collection(
|
client.add(
|
||||||
collection_name=COLLECTION_NAME,
|
collection_name=COLLECTION_NAME,
|
||||||
vectors={
|
documents=documents,
|
||||||
VECTOR_FIELD_NAME: vectors
|
metadata=metadata,
|
||||||
},
|
ids=tqdm(range(len(documents))),
|
||||||
payload=payload,
|
parallel=0,
|
||||||
ids=None, # Vector ids will be assigned automatically
|
|
||||||
batch_size=256 # How many vectors will be uploaded in a single request?
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
generate_embeddings()
|
|
||||||
upload_embeddings()
|
upload_embeddings()
|
||||||
|
|
|
@ -1,50 +1,37 @@
|
||||||
import json
|
import json
|
||||||
import os.path
|
import os.path
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from qdrant_client import QdrantClient, models
|
from qdrant_client import QdrantClient, models
|
||||||
|
from qdrant_client.qdrant_fastembed import SUPPORTED_EMBEDDING_MODELS
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, VECTOR_FIELD_NAME
|
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, \
|
||||||
|
VECTOR_FIELD_NAME, EMBEDDINGS_MODEL
|
||||||
# Define the CSV file path and NPY file path
|
|
||||||
csv_file_path = os.path.join(DATA_DIR, "organizations.csv")
|
|
||||||
npy_file_path = os.path.join(DATA_DIR, "embeddings.npy")
|
|
||||||
|
|
||||||
|
|
||||||
def upload_embeddings():
|
def upload_embeddings():
|
||||||
|
|
||||||
client = QdrantClient(
|
client = QdrantClient(
|
||||||
url=QDRANT_URL,
|
url=QDRANT_URL,
|
||||||
api_key=QDRANT_API_KEY,
|
api_key=QDRANT_API_KEY,
|
||||||
prefer_grpc=True,
|
prefer_grpc=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
vectors_path = os.path.join(DATA_DIR, 'startup_vectors.npy')
|
|
||||||
vectors = np.load(vectors_path)
|
|
||||||
vector_size = vectors.shape[1]
|
|
||||||
|
|
||||||
payload_path = os.path.join(DATA_DIR, 'startups.json')
|
payload_path = os.path.join(DATA_DIR, 'startups.json')
|
||||||
payload = []
|
payload = []
|
||||||
|
documents = []
|
||||||
|
|
||||||
with open(payload_path) as fd:
|
with open(payload_path) as fd:
|
||||||
for line in fd:
|
for line in fd:
|
||||||
obj = json.loads(line)
|
obj = json.loads(line)
|
||||||
# Rename fields to unified schema
|
# Rename fields to unified schema
|
||||||
obj[TEXT_FIELD_NAME] = obj.pop('description')
|
documents.append(obj.pop('description'))
|
||||||
obj["logo_url"] = obj.pop("images")
|
obj["logo_url"] = obj.pop("images")
|
||||||
obj["homepage_url"] = obj.pop("link")
|
obj["homepage_url"] = obj.pop("link")
|
||||||
payload.append(obj)
|
payload.append(obj)
|
||||||
|
|
||||||
client.recreate_collection(
|
client.recreate_collection(
|
||||||
collection_name=COLLECTION_NAME,
|
collection_name=COLLECTION_NAME,
|
||||||
vectors_config={
|
vectors_config=client.get_fastembed_vector_params(on_disk=True),
|
||||||
VECTOR_FIELD_NAME: models.VectorParams(
|
|
||||||
size=vector_size,
|
|
||||||
distance=models.Distance.COSINE,
|
|
||||||
on_disk=True,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
# Quantization is optional, but it can significantly reduce the memory usage
|
# Quantization is optional, but it can significantly reduce the memory usage
|
||||||
quantization_config=models.ScalarQuantization(
|
quantization_config=models.ScalarQuantization(
|
||||||
scalar=models.ScalarQuantizationConfig(
|
scalar=models.ScalarQuantizationConfig(
|
||||||
|
@ -69,15 +56,12 @@ def upload_embeddings():
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
client.upload_collection(
|
client.add(
|
||||||
collection_name=COLLECTION_NAME,
|
collection_name=COLLECTION_NAME,
|
||||||
vectors={
|
documents=documents,
|
||||||
VECTOR_FIELD_NAME: vectors
|
metadata=payload,
|
||||||
},
|
ids=tqdm(range(len(payload))),
|
||||||
payload=payload,
|
parallel=0,
|
||||||
ids=None, # Vector ids will be assigned automatically
|
|
||||||
batch_size=64, # How many vectors will be uploaded in a single request?
|
|
||||||
parallel=10,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,24 +2,21 @@ from typing import List
|
||||||
|
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from qdrant_client.http.models.models import Filter
|
from qdrant_client.http.models.models import Filter
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
|
|
||||||
from qdrant_demo.config import QDRANT_URL, QDRANT_API_KEY, VECTOR_FIELD_NAME
|
from qdrant_demo.config import QDRANT_URL, QDRANT_API_KEY
|
||||||
|
|
||||||
|
|
||||||
class NeuralSearcher:
|
class NeuralSearcher:
|
||||||
|
|
||||||
def __init__(self, collection_name: str):
|
def __init__(self, collection_name: str):
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
self.model = SentenceTransformer('all-MiniLM-L6-v2')
|
|
||||||
self.qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
|
self.qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
|
||||||
|
|
||||||
def search(self, text: str, filter_: dict = None) -> List[dict]:
|
def search(self, text: str, filter_: dict = None) -> List[dict]:
|
||||||
vector = self.model.encode(text).tolist()
|
hits = self.qdrant_client.query(
|
||||||
hits = self.qdrant_client.search(
|
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
query_vector=(VECTOR_FIELD_NAME, vector),
|
query_text=text,
|
||||||
query_filter=Filter(**filter_) if filter_ else None,
|
query_filter=Filter(**filter_) if filter_ else None,
|
||||||
limit=5
|
limit=5
|
||||||
)
|
)
|
||||||
return [hit.payload for hit in hits]
|
return [hit.metadata for hit in hits]
|
||||||
|
|
Loading…
Reference in New Issue