switch back to sentence-transformer model

This commit is contained in:
generall 2023-10-17 00:42:47 +02:00
parent 6318d5a653
commit 9c426ec984
5 changed files with 13 additions and 11 deletions

View File

@ -36,7 +36,7 @@ COPY ./poetry.lock /app
COPY --from=build-step /app/dist /app/static
RUN poetry install --no-interaction --no-ansi --no-root --without dev
RUN python -c 'from fastembed.embedding import DefaultEmbedding; DefaultEmbedding("BAAI/bge-small-en")'
RUN python -c 'from fastembed.embedding import DefaultEmbedding; DefaultEmbedding("sentence-transformers/all-MiniLM-L6-v2")'
# Finally copy the application source code and install root
COPY qdrant_demo /app/qdrant_demo

View File

@ -9,7 +9,6 @@ QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333/")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "text-demo")
EMBEDDINGS_MODEL = os.environ.get("EMBEDDINGS_MODEL", "BAAI/bge-small-en")
EMBEDDINGS_MODEL = os.environ.get("EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
VECTOR_FIELD_NAME = "fast-bge-small-en"
TEXT_FIELD_NAME = "short_description"
TEXT_FIELD_NAME = "document"

View File

@ -4,7 +4,7 @@ import pandas as pd
from qdrant_client import QdrantClient, models
from tqdm import tqdm
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, VECTOR_FIELD_NAME
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, EMBEDDINGS_MODEL
# Define the CSV file path and NPY file path
csv_file_path = os.path.join(DATA_DIR, "organizations.csv")
@ -16,7 +16,9 @@ def upload_embeddings():
api_key=QDRANT_API_KEY,
)
df = pd.read_csv(csv_file_path, nrows=1000)
client.set_model(EMBEDDINGS_MODEL)
df = pd.read_csv(csv_file_path)
documents = df['short_description'].tolist()
df.drop(columns=['short_description'], inplace=True)
metadata = df.to_dict('records')
@ -53,7 +55,7 @@ def upload_embeddings():
documents=documents,
metadata=metadata,
ids=tqdm(range(len(documents))),
parallel=0,
parallel=6,
)

View File

@ -2,11 +2,9 @@ import json
import os.path
from qdrant_client import QdrantClient, models
from qdrant_client.qdrant_fastembed import SUPPORTED_EMBEDDING_MODELS
from tqdm import tqdm
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, \
VECTOR_FIELD_NAME, EMBEDDINGS_MODEL
from qdrant_demo.config import DATA_DIR, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME, TEXT_FIELD_NAME, EMBEDDINGS_MODEL
def upload_embeddings():
@ -16,6 +14,8 @@ def upload_embeddings():
prefer_grpc=True,
)
client.set_model(EMBEDDINGS_MODEL)
payload_path = os.path.join(DATA_DIR, 'startups.json')
payload = []
documents = []

View File

@ -3,7 +3,7 @@ from typing import List
from qdrant_client import QdrantClient
from qdrant_client.http.models.models import Filter
from qdrant_demo.config import QDRANT_URL, QDRANT_API_KEY
from qdrant_demo.config import QDRANT_URL, QDRANT_API_KEY, EMBEDDINGS_MODEL
class NeuralSearcher:
@ -11,6 +11,7 @@ class NeuralSearcher:
def __init__(self, collection_name: str):
self.collection_name = collection_name
self.qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
self.qdrant_client.set_model(EMBEDDINGS_MODEL)
def search(self, text: str, filter_: dict = None) -> List[dict]:
hits = self.qdrant_client.query(